コード例 #1
0
    def test_intercept(self, distribution, n_categories, noise_scale):
        """With an intercept, fewer than half of the column means may stay within 0.05 of the intercept-free ones."""
        graph = StructureModel()
        graph.add_node("A")

        shared_kwargs = dict(
            noise_scale=noise_scale,
            n_categories=n_categories,
            seed=10,
        )
        without_intercept = generate_categorical_dataframe(
            graph, 100000, distribution, intercept=False, **shared_kwargs
        )
        with_intercept = generate_categorical_dataframe(
            graph, 100000, distribution, intercept=True, **shared_kwargs
        )

        # NOTE: as n_categories grows, the probability that at least one category mean with intercept=True
        # coincides with the corresponding mean with intercept=False tends to 1.0, so only require that fewer
        # than half of the means stay (nearly) unchanged.
        close_means = np.isclose(
            with_intercept.mean(axis=0),
            without_intercept.mean(axis=0),
            atol=0.05,
            rtol=0,
        )
        assert close_means.sum() < n_categories / 2
コード例 #2
0
    def test_intercept(self, distribution, noise_scale):
        """An intercept must shift the mean of the generated column but keep its std within 1%."""
        graph = StructureModel()
        graph.add_node("123")

        common = dict(
            n_samples=100000,
            distribution=distribution,
            noise_scale=noise_scale,
            seed=10,
        )
        column_noint = generate_continuous_data(graph, intercept=False, **common)[:, 0]
        column_intercept = generate_continuous_data(graph, intercept=True, **common)[:, 0]

        assert not np.isclose(column_noint.mean(), column_intercept.mean())
        assert np.isclose(column_noint.std(), column_intercept.std(), rtol=0.01)
コード例 #3
0
ファイル: test_plotting.py プロジェクト: zeta1999/causalnex
    def test_all_nodes_exist(self):
        """Both connected and unconnected nodes should exist"""
        structure = StructureModel([("a", "b")])
        structure.add_node("c")  # isolated node, not part of any edge
        plotted = plot_structure(structure)

        expected_nodes = ["a", "b", "c"]
        assert all(name in plotted.nodes() for name in expected_nodes)
コード例 #4
0
 def test_get_indices_empty_iterator(self, schema):
     """A node without parents must map to an empty list of feature indices."""
     graph = StructureModel()
     graph.add_node(10)  # isolated node, so it has no predecessors
     feature_mapper = VariableFeatureMapper(schema)
     indices = feature_mapper.get_indices(graph.predecessors(10))
     assert len(indices) == 0
     assert isinstance(indices, list)
コード例 #5
0
    def test_isolates(self):
        """The target subgraph of a node without any edges contains only that node."""
        model = StructureModel()
        model.add_nodes_from([1, 3, 5, 2, 7])
        subgraph = model.get_target_subgraph(1)

        expected = StructureModel()
        expected.add_node(1)

        assert set(subgraph.nodes) == set(expected.nodes)
        assert set(subgraph.edges) == set(expected.edges)
コード例 #6
0
    def test_isolates(self):
        """The Markov blanket of a node without any edges contains only that node."""
        isolated_nodes = [1, 3, 5, 2, 7]
        model = StructureModel()
        model.add_nodes_from(isolated_nodes)
        blanket = model.get_markov_blanket(1)

        expected = StructureModel()
        expected.add_node(1)

        assert set(blanket.nodes) == set(expected.nodes)
        assert set(blanket.edges) == set(expected.edges)
コード例 #7
0
    def test_intercept(self, distribution, noise_scale):
        """An intercept must change the mean of the generated binary column."""
        graph = StructureModel()
        graph.add_node("123")

        common = dict(noise_scale=noise_scale, seed=10)
        column_noint = generate_binary_data(graph, 100000, distribution, intercept=False, **common)[:, 0]
        column_intercept = generate_binary_data(graph, 100000, distribution, intercept=True, **common)[:, 0]

        assert not np.isclose(column_noint.mean(), column_intercept.mean())
コード例 #8
0
ファイル: agents.py プロジェクト: sa-and/interventional_RL
class CausalAgent(ABC):
    """
    The base class for all agents which maintain an epistemic causal graph about their environment.

    Concrete agents must implement the abstract methods get_action_from_actionspace_sample,
    store_observation_per_action and update_model_per_action.
    """
    # Variable (node) names; __init__ turns an int n into ['x0', ..., 'x<n-1>'], otherwise stores the given value.
    var_names: Union[int, List[str]]
    # The agent's current epistemic causal graph.
    causal_model: StructureModel
    # Stored observations: maps '(var,value)' (interventions) or '(None, None)' (pure observations)
    # to a DataFrame with one column per variable name.
    collected_data: dict
    # Initialised to an empty list in __init__; filled elsewhere (not visible here).
    actions: List[Any]
    # Passed in via __init__ (default 1); how it is used is not visible here.
    state_repeats: int

    def __init__(self, vars: Union[int, List[str]],
                 causal_graph: StructureModel = None,
                 env_type: str = 'Switchboard',
                 state_repeats: int = 1,
                 allow_interventions: bool = True):
        self.allow_interventions = allow_interventions
        self.env_type = env_type
        if type(vars) == int:
            self.var_names = ['x' + str(i) for i in range(vars)]
        else:
            self.var_names = vars

        # initialize causal model
        if causal_graph:
            self.causal_model = causal_graph
        else:
            self.causal_model = StructureModel()
            [self.causal_model.add_node(name) for name in self.var_names]
            self.reset_causal_model()

        # initialize the storages for observational and interventional data.
        self.collected_data = {}

        self.action_space = None
        self.observation_space = None
        self.actions = []
        self.current_action = None
        self.state_repeats = state_repeats

    # --------------------------- Methods for maintaining the causal structure of the agent ---------------------------
    def set_causal_model(self, causal_model: StructureModel):
        self.causal_model = causal_model

    def reset_causal_model(self, mode: str = 'random'):
        """
        Sets the causal graph of the agent to either a graph with random edges or without edges at all.
        :param mode: 'random' or 'empty'
        """
        all_pairs = [(v[0], v[1]) for v in permutations(self.var_names, 2)]

        if mode == 'random':
            random.shuffle(all_pairs)
            for p in all_pairs:
                self.update_model(p, random.choice([0, 1, 2]))

        elif mode == 'empty':
            # delete all edges
            for p in all_pairs:
                self.update_model(p, 0)
        else:
            raise TypeError('No reset defined for mode ' + mode)

    def update_model(self, edge: Tuple[str, str],
                     manipulation: int,
                     allow_disconnecting: bool = True,
                     allow_cycles: bool = True) -> bool:
        """
        Updates model according to action and returns the success of the operation. Reversing and removing an edge that
        doesn't exists has no effect. Adding an edge which already exists has no effect.

        :param edge: The edge to be manipulated. e.g. (X0, X1)
        :param manipulation: 0 = remove edge, 1 = add edge, 2 = reverse edge
        :param allow_disconnecting: If true, manipulations which disconnect the causal graph can be executed.
        :param allow_cycles: If true, manipulations which result in a cycle can be executed.
        :return: True if the manipulation was successful. False if it wasn't or it was illegal according to
        'allow_disconnecting' or 'allow_cycles'.
        """

        if manipulation == 0:  # remove edge if exists
            if self.causal_model.has_edge(edge[0], edge[1]):
                self.causal_model.remove_edge(edge[0], edge[1])
                removed_edge = (edge[0], edge[1])
            else:
                return False

            # disconnected graph
            if not allow_disconnecting and nx.number_weakly_connected_components(self.causal_model) > 1:
                self.causal_model.add_edge(removed_edge[0], removed_edge[1])
                return False

        elif manipulation == 1:  # add edge
            if not self.causal_model.has_edge(edge[0], edge[1]):  # only add edge if not already there
                self.causal_model.add_edge(edge[0], edge[1])
            else:
                return False

            if not nx.is_directed_acyclic_graph(self.causal_model) and not allow_cycles:  # check if became cyclic
                self.causal_model.remove_edge(edge[0], edge[1])
                return False

        elif manipulation == 2:  # reverse edge
            if self.causal_model.has_edge(edge[0], edge[1]):
                self.causal_model.remove_edge(edge[0], edge[1])
                self.causal_model.add_edge(edge[1], edge[0])
                added_edge = (edge[1], edge[0])
            else:
                return False

            if not nx.is_directed_acyclic_graph(self.causal_model) and not allow_cycles:  # check if became cyclic
                self.causal_model.remove_edge(added_edge[0], added_edge[1])
                self.causal_model.add_edge(added_edge[1], added_edge[0])
                return False

        return True

    def display_causal_model(self) -> None:
        """
        Draw the agent's causal model with a circular layout and node labels, and show the figure.

        The return annotation is None: this method returns normally (NoReturn would claim it never returns).
        """
        fig, ax = plt.subplots()
        nx.draw_circular(self.causal_model, ax=ax, with_labels=True)
        fig.show()

    def get_graph_state(self) -> List[float]:
        """
        Get a list of values that represents the state of an edge in the causal graph for each possible graph.
        The edges are ordered in lexographical order.

        Example:
        In a 3 node graph there are the potential edges: 0-1, 0-2, 1-2. The list [0, 0.5, 1] represents the
        graph 0x1, 0->2, 1<-2, where x means that there is no edge.

        :return: state of the graph
        """
        graph_state = []
        possible_edges = [e for e in combinations(self.var_names, 2)]
        for e in possible_edges:
            if self.causal_model.has_edge(e[0], e[1]):
                graph_state.append(0.5)
            elif self.causal_model.has_edge(e[1], e[0]):
                graph_state.append(1.0)
            else:
                graph_state.append(0.0)
        return graph_state

    # ------------------------- Methods for evaluation of causal inference power of the agent -------------------------
    def get_est_avg_causal_effect(self, query: str, intervened_var: str, end: Any, initial: Any) -> float:
        """
        Estimates the average causal effect of the 'intervened_var' on the 'query' variable when changing the value of
        the 'intervened_var' from 'initial' to 'any' in the collected data of the agent. This, effectively, approximates
        E(P(query|do(intervened_var=end)) - E(P(query|do(intervened_var=initial))

        :param query: The variable on which the effect is measured.
        :param intervened_var: The variable on which the intervention is performed.
        :param end: end value of the intervened var
        :param initial: start value of the intervened var
        :return: estimated average causal effect
        """
        exp_val1 = self._get_expected_value(self.get_est_postint_distrib(query, intervened_var, end))
        exp_val2 = self._get_expected_value(self.get_est_postint_distrib(query, intervened_var, initial))

        return exp_val1 - exp_val2

    def compare_edge_to_data(self, edge: Tuple[str, str], threshold: float = 0.0) -> bool:
        """
        Checks whether the edge of the model corresponds to an actual causal effect in the collected interventional
        data. So for a given edge A -> B it checks whether P(B|do(A=x)) != P(B|do(A=y)) holds in the collected
        interventional data set.
        Note that this direct effect suggested in the model could actually be an indirect one in the data. Still, this
        methods returns true in this case.

        :param edge: The edge to be checked. E.g. ('x1', 'x2')
        :param threshold: The value from which on the effect is assumed to be present.
        :return: Whether the causal edge is backed up by the interventional data.
        """
        assert self.causal_model.has_edge(edge[0], edge[1]), 'The given edge is not part of the current model.'

        if '(' + edge[0] + ',True)' in self.collected_data and '(' + edge[0] + ',False)' in self.collected_data:
            est_causal_effect = self.get_est_avg_causal_effect(edge[1], edge[0], True, False)
            if abs(est_causal_effect) >= threshold:
                return True
            else:
                return False

        elif '(' + edge[0] + ',0.0)' in self.collected_data and '(' + edge[0] + ',5.0)' in self.collected_data:
            est_causal_effect = self.get_est_avg_causal_effect(edge[1], edge[0], 0.0, 5.0)
            if abs(est_causal_effect) >= threshold:
                return True
            else:
                return False

        return True

    def has_wrong_edges(self, threshold: float = 0.0) -> int:
        """
        Determines how many edges in the current causal model do not have a causal effect in the interventional
        data set that is bigger than the given threshold.
        :param threshold:
        :return: number of 'wrong' edges
        """
        count = 0
        for e in self.causal_model.edges:
            if not self.compare_edge_to_data(e, threshold):
                count += 1
        return count

    def reverse_wrong_edges(self, threshold: float = 0.0) -> NoReturn:
        """
        Checks all edges whether they are the wrong way around and reverses those that are.

        :param threshold:
        """
        wrong_edges = []
        for e in self.causal_model.edges:
            if not self.compare_edge_to_data(e, threshold):
                wrong_edges.append(e)

        for e in wrong_edges:
            self.update_model(e, 2)
    
    def edge_is_missing(self, edge: Tuple[str, str], threshold: float = 0.0) -> bool:
        """
        Checks whether for the given edge (which is not part of the model) there is a causal effect in the collected
        interventional data. If true, there should be a directed path between edge[0] and edge[1] in the model but
        there is none. This means that along the path from edge[0] to edge[1] at least one edge is missing.
        
        :param edge: Edge to check
        :param threshold: Value from which on the effect is to be considered an actual effect.
        :return: Whether, according to the interventional data there should be an edge but is none.
        """
        if edge in self.causal_model.edges:
            return False

        elif nx.has_path(self.causal_model, edge[0], edge[1]):
            return False
        
        else:
            if '(' + edge[0] + ',True)' in self.collected_data and '(' + edge[0] + ',False)' in self.collected_data:
                effect = self.get_est_avg_causal_effect(edge[1], edge[0], True, False)
                return abs(effect) >= threshold
            elif '(' + edge[0] + ',0.0)' in self.collected_data and '(' + edge[0] + ',5.0)' in self.collected_data:
                effect = self.get_est_avg_causal_effect(edge[1], edge[0], 5.0, 0.0)
                return abs(effect) >= threshold
            else:
                return True
        
    def has_missing_edges(self, threshold: float = 0.0) -> int:
        """
        Return the maximal number of missing edges in the model according to the collected interventional data.

        For every node n and every node nn returned by nx.non_neighbors(self.causal_model, n), the pair
        (str(n), str(nn)) is checked with edge_is_missing(); the number of pairs reported missing is returned.
        Only a maximum can be given, because interventions on a single variable cannot separate direct from
        indirect effects.

        :param threshold: Minimum absolute average causal effect that counts as an effect.
        :return: number of pairs of nodes that look like a missing edge

        Example
        ---------
        Let the ground-truth causal model be A -> B -> C and the agent's model be A   B -> C (the edge between
        A and B is missing). This method returns 2: besides the pair (A, B), the indirect effect of A on C
        (through B) cannot be distinguished from a direct effect A -> C.

        Data from interventions on a single variable cannot distinguish A -> B -> C from the same graph with an
        additional direct edge A -> C, hence at most 2 edges are missing.

        Important: once there is any directed path from A to C, no edge from A to C is considered missing.
        E.g. applied to the model A -> B -> C this method returns 0;
        applied to the model A -> C <- B it returns 1, as the edge A -> B is missing.
        These examples assume a positive threshold and interventional data on every variable: with
        threshold=0.0 every pair with such data satisfies abs(effect) >= 0, and pairs without data are always
        counted as missing.
        """
        graph = self.causal_model
        # check which causal relationships are missing in the graph
        return sum(
            1
            for node in graph.nodes
            for other in nx.non_neighbors(graph, node)
            if self.edge_is_missing((str(node), str(other)), threshold)
        )

    def get_est_postint_distrib(self, query: str, intervened_var: str, val: Any) -> pd.Series:
        """Computes and returns P(Query | do(action))"""
        key = '('+intervened_var+','+str(val)+')'
        query_dataframe = self.collected_data[key]
        query_dataframe = query_dataframe.groupby(query).size()/len(query_dataframe)
        return query_dataframe.rename('P('+query+'|do'+key+')')

    def store_observation(self, obs: List[Any], intervened_var: Optional[str], val: Any) -> NoReturn:
        """
        Stores the observation of the environment in the appropriate dataframe.
        If no intervention is made (intervened_var=None), the observation is saved in the purely observational
        DataFrame.

        :param obs: Observation list of values
        :param intervened_var: On which var to intervene. Can be None
        :param val: The assigned value of the intervened variable
        """
        if intervened_var == None:
            key = '(None, None)'
        else:
            key = '(' + intervened_var + ',' + str(val) + ')'

        # put the observation in the right dataframe
        obs_dict = pd.DataFrame().append({self.var_names[i]: obs[i] for i in range(len(self.var_names))},
                                         ignore_index=True)

        if key in self.collected_data:
            self.collected_data[key] = self.collected_data[key].append(obs_dict, ignore_index=True)
        else:
            self.collected_data[key] = obs_dict

    def get_est_cond_distr(self, query: str, var: str, val: Any) -> DataFrame:
        """
        Compute and return P(query | var = val)
        """
        obs_data = self.collected_data['(None, None)']
        obs_data = obs_data[obs_data[var] == val]
        obs_data = obs_data.groupby(query).size() / len(obs_data)
        return obs_data.rename('P('+query+'|'+var+'='+str(val)+')')

    def graph_is_learned(self, threshold: float = 0.0) -> bool:
        """
        Determined if there are any edges in the epistemic graph of the agent which can't be found in it's collected
        interventional data and whether there are edges in the collected interventional data which are not part of
        the graph of the agent.
        :param threshold: effect threshold of average causal effect.
        :return: True if no edges are missing and no edges are wrong.
        """
        n_wrong_edges = self.has_wrong_edges(threshold)
        print('wrong edges: ', n_wrong_edges)
        n_missing_edges = self.has_missing_edges(threshold)
        print('missing edges: ', n_missing_edges)
        return n_wrong_edges == 0 and n_missing_edges == 0

    def is_legal_intervention(self, interv_var: str) -> bool:
        """
        Check whether an intervention on 'interv_var' keeps the causal graph connected.

        On a copy of the model, every edge pointing into 'interv_var' is removed; the intervention counts as
        legal (for the causalnex library) if the remaining graph has at most one weakly connected component.
        The agent's own model is not modified.

        :param interv_var: variable to intervene on.
        :return: False if an intervention on 'interv_var' would disconnect the graph.
        """
        pruned = self.causal_model.copy()
        parents = [node for node in nx.nodes(pruned) if pruned.has_edge(node, interv_var)]
        for parent in parents:
            pruned.remove_edge(parent, interv_var)
        return nx.number_weakly_connected_components(pruned) <= 1

    @staticmethod
    def _get_expected_value(distribution: pd.Series) -> float:
        if type(distribution.index[0] == bool):
            distribution = distribution.rename({True: 1, False: 0})
        return sum(distribution.index.values * distribution._values)

    # ---------------------------------------------- Abstract methods ----------------------------------------------
    @abstractmethod
    def get_action_from_actionspace_sample(self, sample: Any):
        raise NotImplementedError

    @abstractmethod
    def store_observation_per_action(self, obs: List[Any]):
        raise NotImplementedError

    @abstractmethod
    def update_model_per_action(self, action: Any):
        raise NotImplementedError