def test_mixed_conditions(self):
    """Check that applying `given` and then `condition` matches the precomputed fixtures."""
    # Fix D=1, then condition the result on C.
    expected_d = self.p_AB__C__D_1
    actual_d = condition(given(self.joint, D=1), 'C')
    self.assertTrue(series_are_equivalent(expected_d, actual_d))

    # Fix C=2, then condition the result on D.
    expected_c = self.p_AB__C_2__D
    actual_c = condition(given(self.joint, C=2), 'D')
    self.assertTrue(series_are_equivalent(expected_c, actual_c))
def test_given_conditions(self):
    """Check `given` against fixtures when fixing one, two and three variables."""
    # One fixed variable: D=1 matches the fixture; fixing A, B or C at 1 instead does not.
    one_fixed = self.p_ABC__D_1
    self.assertTrue(series_are_equivalent(one_fixed, given(self.joint, D=1)))
    for var_name in ('A', 'B', 'C'):
        self.assertFalse(
            series_are_equivalent(one_fixed, given(self.joint, **{var_name: 1})))

    # Two fixed variables: (C=1, D=2) matches the fixture; every other ordered
    # pair of distinct variables assigned the values 1 and 2 does not.
    two_fixed = self.p_AB__C_1__D_2
    self.assertTrue(series_are_equivalent(two_fixed, given(self.joint, C=1, D=2)))
    for first, second in product(self.vars, self.vars):
        if first != second and (first, second) != ('C', 'D'):
            self.assertFalse(
                series_are_equivalent(
                    two_fixed, given(self.joint, **{first: 1, second: 2})))

    # Three fixed variables at once.
    self.assertTrue(
        series_are_equivalent(
            self.p_A__B_1__C_2__D_3, given(self.joint, B=1, C=2, D=3)))
def given(self, **given_vals) -> DiscreteDistribution:
    """
    Return the Discrete Distribution at the given values of the conditional variables.

    :param given_vals: Mapping of each conditional variable name to the value it is
                       fixed at. Every name in ``self._cond_vars`` must be present.
    :raises ValueError: If a value is missing for any conditioned variable.
    """
    # Every conditioned variable must appear among the supplied keys.
    missing_vars = set(self._cond_vars) - set(given_vals)
    if missing_vars:
        raise ValueError(
            'Must supply values for all conditioned variables to get to joint distribution.'
        )
    # Inside a method body the name `given` resolves to the module-scope helper,
    # not to this method (class attributes are not in a method's lookup scope).
    filtered = given(self._data, **given_vals)
    return DiscreteDistribution(filtered, given_conditions=given_vals)
def given(self, **given_conditions) -> 'DiscreteDistribution':
    """
    Condition on values of variables.

    :param given_conditions: Dict[{name}__{comparator}, value] for each conditioned variable.
    :raises ValueError: If any key does not name a member of the joint distribution.
    :return: A new DiscreteDistribution built from the conditioned data. Its
             ``given_var_names`` are this distribution's given variables followed by
             the newly conditioned names, in call order. Its ``given_conditions``
             merge the existing conditions with the new ones.
    """
    # Keep the kwargs in call order (dicts preserve insertion order). Iterating a set of
    # strings here would give an order that changes between interpreter runs under hash
    # randomization, making ``given_var_names`` non-deterministic.
    names_comps = list(given_conditions)
    # check input variables
    if not all(valid_name_comparator(name_comp, self._joints)
               for name_comp in names_comps):
        raise ValueError('Given variables must be members of joint distribution.')
    # calculate conditional distribution (module-scope `given` helper, not this method)
    data = given(self._data, **given_conditions)
    # New conditions override existing ones that use the same key.
    cond_values = {**self._given_conditions, **given_conditions}
    return DiscreteDistribution(
        data=data,
        given_var_names=self._given_vars + [
            cond_name(k, self.var_names) for k in names_comps
        ],
        given_conditions=cond_values,
    )
def test_chain_mixed_conditions(self):
    """Check that given(C=1) then condition('D') differs from the reverse order."""
    given_then_condition = condition(given(self.joint, C=1), 'D')
    condition_then_given = given(condition(self.joint, 'D'), C=1)
    self.assertFalse(
        series_are_equivalent(given_then_condition, condition_then_given))
def test_chain_given_conditions(self):
    """Check that chaining given(C=1) and given(D=2) gives the same result in either order."""
    c_first = given(given(self.joint, C=1), D=2)
    d_first = given(given(self.joint, D=2), C=1)
    self.assertTrue(series_are_equivalent(c_first, d_first))