예제 #1
0
    def setUp(self) -> None:
        """Build the rain / sprinkler / jack / tracey fixture shared by tests.

        Creates two binary marginals via ``Discrete.binary`` (``rain`` with
        parameter 0.2, ``sprinkler`` with parameter 0.1; the first argument
        appears to be P(var=1) — confirm against ``Discrete.binary``) and two
        conditional tables via ``Conditional.from_probs``. Table keys are
        ordered as (joint value, *conditional values).
        """
        # Keys: (jack, rain) -> P(jack | rain); each rain column sums to 1.
        jack_table = {
            (1, 1): 1,
            (1, 0): 0.2,
            (0, 1): 0,
            (0, 0): 0.8
        }
        # Keys: (tracey, rain, sprinkler) -> P(tracey | rain, sprinkler).
        tracey_table = {
            (1, 1, 0): 1,
            (1, 1, 1): 1,
            (1, 0, 1): 0.9,
            (1, 0, 0): 0,
            (0, 1, 0): 0,
            (0, 1, 1): 0,
            (0, 0, 1): 0.1,
            (0, 0, 0): 1
        }

        self.r = Discrete.binary(0.2, 'rain')
        self.s = Discrete.binary(0.1, 'sprinkler')
        self.j__r = Conditional.from_probs(
            jack_table,
            joint_variables='jack',
            conditional_variables='rain',
        )
        self.t__r_s = Conditional.from_probs(
            tracey_table,
            joint_variables='tracey',
            conditional_variables=['rain', 'sprinkler'],
        )
예제 #2
0
    def test_binary_from_probs(self):
        """``binary_from_probs`` must match the equivalent explicit table.

        The binary constructor takes only P(C=1 | A, B). The explicit
        ``from_probs`` table lists both C=1 and C=0 rows, with the C=0 rows
        computed as the complements. The two resulting ``.data`` objects
        must compare equal.
        """
        # P(C=1 | A, B), keyed by (A, B).
        p_c_is_1 = {
            (0, 0): 0.1,
            (0, 1): 0.99,
            (1, 0): 0.8,
            (1, 1): 0.25,
        }
        # All C=1 rows first, then all C=0 rows (complements), keyed (C, A, B).
        explicit_table = {(1,) + ab: p for ab, p in p_c_is_1.items()}
        explicit_table.update({(0,) + ab: 1 - p for ab, p in p_c_is_1.items()})

        from_full_table = Conditional.from_probs(
            data=explicit_table,
            joint_variables='C',
            conditional_variables=['A', 'B']
        ).data
        from_binary_table = Conditional.binary_from_probs(
            data=dict(p_c_is_1),
            joint_variable='C',
            conditional_variables=['A', 'B']
        ).data
        self.assertTrue(from_full_table.equals(from_binary_table))
예제 #3
0
    def test_init__with_names(self):
        """Construct a Conditional purely from the axis names of its data.

        No explicit variable arguments are passed. The index name
        ('language') and column name ('country') are set on the fixture
        table, presumably so ``Conditional`` can infer the joint and
        conditional variables from them.
        """
        table = self.get_language_probs()
        table.columns.name = 'country'
        table.index.name = 'language'
        self.check_conditional(Conditional(data=table))
예제 #4
0
    def test_given_one_variable(self):
        """Fixing A=0 in an XOR table leaves P(A_xor_B=1 | B) equal to B."""
        # P(A_xor_B=1 | A, B), keyed by (A, B).
        xor_table = {
            (0, 0): 0,
            (0, 1): 1,
            (1, 0): 1,
            (1, 1): 0,
        }
        # With A fixed to 0, XOR reduces to the identity on B.
        identity_on_b = {
            0: 0,
            1: 1,
        }

        xor = Conditional.binary_from_probs(
            data=xor_table,
            joint_variable='A_xor_B',
            conditional_variables=['A', 'B'],
        )
        expected = Conditional.binary_from_probs(
            data=identity_on_b,
            joint_variable='A_xor_B',
            conditional_variables='B',
        ).data

        self.assertTrue(expected.equals(xor.given(A=0).data))
예제 #5
0
    def test_given_all_variables(self):
        """Fixing A=1 and B=1 in an XOR table yields a marginal where XOR=1 is impossible."""
        # P(A_xor_B=1 | A, B), keyed by (A, B).
        xor_table = {
            (0, 0): 0,
            (0, 1): 1,
            (1, 0): 1,
            (1, 1): 0,
        }
        target = Discrete.binary(0, 'A_xor_B').data

        xor = Conditional.binary_from_probs(
            data=xor_table,
            joint_variable='A_xor_B',
            conditional_variables=['A', 'B'],
        )
        observed = xor.given(A=1, B=1).data

        self.assertTrue(target.equals(observed))
예제 #6
0
    def test_init__with_vars(self):
        """Construct a Conditional with explicitly named variables and state lists.

        Unlike the name-inference path, the fixture table is passed as-is,
        and the joint/conditional variables and their states are given
        explicitly.
        """
        state_lists = {
            'language': self.languages,
            'country': self.countries
        }
        conditional = Conditional(
            data=self.get_language_probs(),
            joint_variables='language',
            conditional_variables='country',
            states=state_lists,
        )
        self.check_conditional(conditional)
예제 #7
0
    def test_from_probs_with_dict(self):
        """``from_probs`` accepts a dict keyed by (language, country) tuples."""
        # Nested view: language -> {country: P(language | country)}.
        by_language = {
            'English': {'England': 0.95, 'Scotland': 0.7, 'Wales': 0.6},
            'Scottish': {'England': 0.04, 'Scotland': 0.3, 'Wales': 0.0},
            'Welsh': {'England': 0.01, 'Scotland': 0.0, 'Wales': 0.4},
        }
        # Flatten to {(language, country): prob}, keeping the original
        # language-major insertion order.
        flat = {
            (language, country): p
            for language, row in by_language.items()
            for country, p in row.items()
        }
        self.check_conditional(
            Conditional.from_probs(
                data=flat,
                joint_variables='language',
                conditional_variables='country'
            )
        )
예제 #8
0
    def test_from_probs_with_series(self):
        """``from_probs`` accepts a Series with a named (country, language) MultiIndex.

        The index level order here (country first) is the reverse of the
        dict-based test. The level names, not their positions, identify which
        level is the joint variable and which is the conditional one.
        """
        probs = Series({
            ('England', 'English'): 0.95,
            ('Scotland', 'English'): 0.7,
            ('Wales', 'English'): 0.6,
            ('England', 'Scottish'): 0.04,
            ('Scotland', 'Scottish'): 0.3,
            ('Wales', 'Scottish'): 0.0,
            ('England', 'Welsh'): 0.01,
            ('Scotland', 'Welsh'): 0.0,
            ('Wales', 'Welsh'): 0.4,
        })
        # Name the MultiIndex levels so from_probs can match them to variables.
        probs.index.names = ['country', 'language']
        language__given__country = Conditional.from_probs(
            data=probs,
            joint_variables='language',
            conditional_variables='country'
        )
        # Debug print removed: tests should assert, not write to stdout.
        self.check_conditional(language__given__country)