def setUp(self) -> None:
    """Build the distributions shared by the tests in this case.

    Creates two binary ``Discrete`` distributions (``self.r`` named
    'rain', ``self.s`` named 'sprinkler') and two ``Conditional`` tables
    (``self.j__r`` for 'jack' given 'rain', ``self.t__r_s`` for 'tracey'
    given 'rain' and 'sprinkler').
    """
    self.r = Discrete.binary(0.2, 'rain')
    self.s = Discrete.binary(0.1, 'sprinkler')

    # Keys appear to be (jack, rain) state pairs - the joint variable
    # first, then the conditional one; confirm against from_probs.
    jack_given_rain = {
        (1, 1): 1,
        (1, 0): 0.2,
        (0, 1): 0,
        (0, 0): 0.8,
    }
    self.j__r = Conditional.from_probs(
        jack_given_rain,
        joint_variables='jack',
        conditional_variables='rain',
    )

    # Keys appear to be (tracey, rain, sprinkler) state triples.
    tracey_given_rain_sprinkler = {
        (1, 1, 0): 1,
        (1, 1, 1): 1,
        (1, 0, 1): 0.9,
        (1, 0, 0): 0,
        (0, 1, 0): 0,
        (0, 1, 1): 0,
        (0, 0, 1): 0.1,
        (0, 0, 0): 1,
    }
    self.t__r_s = Conditional.from_probs(
        tracey_given_rain_sprinkler,
        joint_variables='tracey',
        conditional_variables=['rain', 'sprinkler'],
    )
def test_binary_from_probs(self):
    """Check ``binary_from_probs`` against an explicit ``from_probs`` table.

    The explicit table lists both states of 'C' for every (A, B)
    combination; the binary shorthand supplies only one probability per
    (A, B) combination. The resulting ``data`` objects must be equal.
    """
    parents = ['A', 'B']

    # Full table: keys are (C, A, B); the C=0 rows are written as
    # complements of the C=1 rows.
    full_table = {
        (1, 0, 0): 0.1,
        (1, 0, 1): 0.99,
        (1, 1, 0): 0.8,
        (1, 1, 1): 0.25,
        (0, 0, 0): 1 - 0.1,
        (0, 0, 1): 1 - 0.99,
        (0, 1, 0): 1 - 0.8,
        (0, 1, 1): 1 - 0.25,
    }
    from_full = Conditional.from_probs(
        data=full_table,
        joint_variables='C',
        conditional_variables=parents,
    )

    # Shorthand table: keys are (A, B) only.
    short_table = {
        (0, 0): 0.1,
        (0, 1): 0.99,
        (1, 0): 0.8,
        (1, 1): 0.25,
    }
    from_short = Conditional.binary_from_probs(
        data=short_table,
        joint_variable='C',
        conditional_variables=parents,
    )

    self.assertTrue(from_full.data.equals(from_short.data))
def test_init__with_names(self):
    """Construct a ``Conditional`` from data whose axes are already named.

    The index is named 'language' and the columns 'country' before the
    data is handed to the constructor without any explicit variable
    arguments; the result is validated by ``check_conditional``.
    """
    table = self.get_language_probs()
    table.index.name = 'language'
    table.columns.name = 'country'

    conditional = Conditional(data=table)

    self.check_conditional(conditional)
def test_given_one_variable(self):
    """Fix one of two conditioning variables with ``given``.

    Starting from a binary table for 'A_xor_B' given 'A' and 'B', fixing
    ``A=0`` is expected to produce data equal to that of a binary table
    for 'A_xor_B' given 'B' alone.
    """
    xor_table = {
        (0, 0): 0,
        (0, 1): 1,
        (1, 0): 1,
        (1, 1): 0,
    }
    xor = Conditional.binary_from_probs(
        data=xor_table,
        joint_variable='A_xor_B',
        conditional_variables=['A', 'B'],
    )
    actual = xor.given(A=0).data

    # With A fixed at 0, the table's remaining entries are keyed by B.
    reduced = Conditional.binary_from_probs(
        data={
            0: 0,
            1: 1,
        },
        joint_variable='A_xor_B',
        conditional_variables='B',
    )
    expected = reduced.data

    self.assertTrue(expected.equals(actual))
def test_given_all_variables(self):
    """Fix every conditioning variable with ``given``.

    Fixing both ``A=1`` and ``B=1`` on the 'A_xor_B' table is expected to
    produce data equal to that of ``Discrete.binary(0, 'A_xor_B')``.
    """
    xor_table = {
        (0, 0): 0,
        (0, 1): 1,
        (1, 0): 1,
        (1, 1): 0,
    }
    xor = Conditional.binary_from_probs(
        data=xor_table,
        joint_variable='A_xor_B',
        conditional_variables=['A', 'B'],
    )
    actual = xor.given(A=1, B=1).data

    expected = Discrete.binary(0, 'A_xor_B').data

    self.assertTrue(expected.equals(actual))
def test_init__with_vars(self):
    """Construct a ``Conditional`` with explicit variable names and states.

    The joint variable, conditional variable and the states of each are
    passed to the constructor alongside the data; the result is
    validated by ``check_conditional``.
    """
    variable_states = {
        'language': self.languages,
        'country': self.countries,
    }

    conditional = Conditional(
        data=self.get_language_probs(),
        joint_variables='language',
        conditional_variables='country',
        states=variable_states,
    )

    self.check_conditional(conditional)
def test_from_probs_with_dict(self):
    """Build a ``Conditional`` via ``from_probs`` from a plain dict.

    Keys are (language, country) tuples, i.e. the joint variable's state
    followed by the conditional variable's state. The result is
    validated by ``check_conditional``.
    """
    language_country_probs = {
        ('English', 'England'): 0.95,
        ('English', 'Scotland'): 0.7,
        ('English', 'Wales'): 0.6,
        ('Scottish', 'England'): 0.04,
        ('Scottish', 'Scotland'): 0.3,
        ('Scottish', 'Wales'): 0.0,
        ('Welsh', 'England'): 0.01,
        ('Welsh', 'Scotland'): 0.0,
        ('Welsh', 'Wales'): 0.4,
    }

    conditional = Conditional.from_probs(
        data=language_country_probs,
        joint_variables='language',
        conditional_variables='country',
    )

    self.check_conditional(conditional)
def test_from_probs_with_series(self):
    """Build a ``Conditional`` via ``from_probs`` from a pandas Series.

    The Series is keyed by (country, language) tuples and its index
    levels are named accordingly, so the key order here is the reverse
    of the (joint, conditional) order used for plain-dict input. The
    result is validated by ``check_conditional``.
    """
    probs = Series({
        ('England', 'English'): 0.95,
        ('Scotland', 'English'): 0.7,
        ('Wales', 'English'): 0.6,
        ('England', 'Scottish'): 0.04,
        ('Scotland', 'Scottish'): 0.3,
        ('Wales', 'Scottish'): 0.0,
        ('England', 'Welsh'): 0.01,
        ('Scotland', 'Welsh'): 0.0,
        ('Wales', 'Welsh'): 0.4,
    })
    # Name the index levels to match the variable names passed below.
    probs.index.names = ['country', 'language']

    language__given__country = Conditional.from_probs(
        data=probs,
        joint_variables='language',
        conditional_variables='country',
    )

    # The trailing debug print was dropped: it asserted nothing and only
    # added noise to the test runner's output.
    self.check_conditional(language__given__country)