예제 #1
0
    def test_sort(self):
        """Check that sort() reorders a SampleSpace and its indexes in place."""
        # '12' is deliberately listed before '10' so that sorting has an effect.
        unsorted = ['00', '01', '12', '10']
        space = SampleSpace(unsorted, dit.helpers.get_product_func(str))

        # Before sorting: iteration and index() follow insertion order.
        assert_equal(list(space), unsorted)
        assert_equal([space.index(o) for o in unsorted],
                     list(range(len(unsorted))))

        space.sort()

        # After sorting: '10' now precedes '12', so their indexes swap.
        assert_equal(list(space), sorted(unsorted))
        assert_equal([space.index(o) for o in unsorted], [0, 1, 3, 2])
예제 #2
0
class TestSampleSpace(object):
    """nose-style tests for SampleSpace over two-character string outcomes."""

    def setUp(self):
        # Per-test fixture (xunit-style setUp hook): four outcomes, already
        # in lexicographic order, joined with the str product function.
        outcomes = ['00', '01', '10', '12']
        self.samplespace = outcomes
        self.ss = SampleSpace(outcomes, dit.helpers.get_product_func(str))

    def test_samplespace_auto(self):
        # Building without an explicit product must yield the same outcomes
        # as the fixture built with the str product.
        auto = SampleSpace(['00', '01', '10', '12'])
        assert_equal(list(auto), list(self.ss))

    def test_samplespace(self):
        space = self.ss
        assert_equal(list(space), self.samplespace)
        assert_equal(len(space), 4)
        assert_equal(space.outcome_length(), 2)
        assert_true('00' in space)
        assert_false('22' in space)

    def test_marginalize(self):
        # Removing index 1 leaves the distinct first characters.
        assert_equal(list(self.ss.marginalize([1])), ['0', '1'])

    def test_marginal(self):
        # Keeping only index 1 leaves the distinct second characters.
        assert_equal(list(self.ss.marginal([1])), ['0', '1', '2'])

    def test_coalesce(self):
        # Each outcome becomes a tuple of strings, one per index group.
        coalesced = self.ss.coalesce([[0, 1, 1], [1, 0]])
        expected = [('000', '00'), ('011', '10'), ('100', '01'), ('122', '21')]
        assert_equal(list(coalesced), expected)

    def test_sort(self):
        # '12' is listed before '10' so that sorting has an effect.
        unsorted = ['00', '01', '12', '10']
        space = SampleSpace(unsorted, dit.helpers.get_product_func(str))
        assert_equal(list(space), unsorted)
        before = [space.index(o) for o in unsorted]
        assert_equal(before, list(range(len(unsorted))))

        space.sort()
        assert_equal(list(space), sorted(unsorted))
        after = [space.index(o) for o in unsorted]
        assert_equal(after, [0, 1, 3, 2])
예제 #3
0
class TestSampleSpace(object):
    """pytest-style tests for SampleSpace over two-character string outcomes."""

    @classmethod
    def setup_class(cls):
        # setup_class receives the class object, not an instance, so declare
        # it as a classmethod. The resulting fixture is shared by every test
        # in this class; tests must not mutate it.
        product = dit.helpers.get_product_func(str)
        cls.samplespace = ['00', '01', '10', '12']
        cls.ss = SampleSpace(cls.samplespace, product)

    def test_samplespace_auto(self):
        # Built without an explicit product, the sample space must iterate the
        # same outcomes as the fixture built with the str product.
        samplespace = ['00', '01', '10', '12']
        ss = SampleSpace(samplespace)
        assert list(ss) == list(self.ss)

    def test_samplespace(self):
        assert list(self.ss) == self.samplespace
        assert len(self.ss) == 4
        assert self.ss.outcome_length() == 2
        assert '00' in self.ss
        assert '22' not in self.ss

    def test_marginalize(self):
        # Removing index 1 leaves the distinct first characters.
        ss0 = self.ss.marginalize([1])
        assert list(ss0) == ['0', '1']

    def test_marginal(self):
        # Keeping only index 1 leaves the distinct second characters.
        ss1 = self.ss.marginal([1])
        assert list(ss1) == ['0', '1', '2']

    def test_coalesce(self):
        # Each outcome becomes a tuple of strings, one per index group.
        ss2 = self.ss.coalesce([[0, 1, 1], [1, 0]])
        ss2_ = [('000', '00'), ('011', '10'), ('100', '01'), ('122', '21')]
        assert list(ss2) == ss2_

    def test_sort(self):
        # Uses a private, deliberately unsorted space ('12' before '10') so
        # the shared class fixture is not mutated by sort().
        product = dit.helpers.get_product_func(str)
        samplespace = ['00', '01', '12', '10']
        ss = SampleSpace(samplespace, product)
        assert list(ss) == samplespace
        indexes = [ss.index(i) for i in samplespace]
        assert indexes == list(range(len(samplespace)))

        ss.sort()
        assert list(ss) == sorted(samplespace)
        indexes = [ss.index(i) for i in samplespace]
        assert indexes == [0, 1, 3, 2]
예제 #4
0
class TestSampleSpace(object):
    """pytest-style tests for SampleSpace over two-character string outcomes."""

    @classmethod
    def setup_class(cls):
        # setup_class receives the class object, not an instance, so declare
        # it as a classmethod. The resulting fixture is shared by every test
        # in this class; tests must not mutate it.
        product = dit.helpers.get_product_func(str)
        cls.samplespace = ['00', '01', '10', '12']
        cls.ss = SampleSpace(cls.samplespace, product)

    def test_samplespace_auto(self):
        # Built without an explicit product, the sample space must iterate the
        # same outcomes as the fixture built with the str product.
        samplespace = ['00', '01', '10', '12']
        ss = SampleSpace(samplespace)
        assert list(ss) == list(self.ss)

    def test_samplespace(self):
        assert list(self.ss) == self.samplespace
        assert len(self.ss) == 4
        assert self.ss.outcome_length() == 2
        assert '00' in self.ss
        assert '22' not in self.ss

    def test_marginalize(self):
        # Removing index 1 leaves the distinct first characters.
        ss0 = self.ss.marginalize([1])
        assert list(ss0) == ['0', '1']

    def test_marginal(self):
        # Keeping only index 1 leaves the distinct second characters.
        ss1 = self.ss.marginal([1])
        assert list(ss1) == ['0', '1', '2']

    def test_coalesce(self):
        # Each outcome becomes a tuple of strings, one per index group.
        ss2 = self.ss.coalesce([[0, 1, 1], [1, 0]])
        ss2_ = [('000', '00'), ('011', '10'), ('100', '01'), ('122', '21')]
        assert list(ss2) == ss2_

    def test_sort(self):
        # Uses a private, deliberately unsorted space ('12' before '10') so
        # the shared class fixture is not mutated by sort().
        product = dit.helpers.get_product_func(str)
        samplespace = ['00', '01', '12', '10']
        ss = SampleSpace(samplespace, product)
        assert list(ss) == samplespace
        indexes = [ss.index(i) for i in samplespace]
        assert indexes == list(range(len(samplespace)))

        ss.sort()
        assert list(ss) == sorted(samplespace)
        indexes = [ss.index(i) for i in samplespace]
        assert indexes == [0, 1, 3, 2]
예제 #5
0
    def test_sort(self):
        """Check that sort() reorders a SampleSpace and its indexes in place."""
        unsorted = ['00', '01', '12', '10']
        space = SampleSpace(unsorted, dit.helpers.get_product_func(str))

        # Insertion order is preserved before sorting.
        assert list(space) == unsorted
        positions = [space.index(o) for o in unsorted]
        assert positions == list(range(len(unsorted)))

        space.sort()

        # After sorting, '10' precedes '12', so their indexes swap.
        assert list(space) == sorted(unsorted)
        positions = [space.index(o) for o in unsorted]
        assert positions == [0, 1, 3, 2]
예제 #6
0
파일: prune_expand.py 프로젝트: vreuter/dit
def pruned_samplespace(d, sample_space=None):
    """
    Returns a new distribution with pruned sample space.

    The pruning is such that zero probability outcomes are removed.

    Parameters
    ----------
    d : distribution
        The distribution used to create the pruned distribution.

    sample_space : iterable, optional
        Outcomes with zero probability that should nevertheless be kept in
        the sample space. Any iterable of hashable outcomes is accepted (it
        is converted to a set). If `None`, then all outcomes with zero
        probability will be removed.

    Returns
    -------
    pd : distribution
        The distribution with a pruned sample space.

    """
    # Outcomes listed here survive pruning even when their probability is
    # exactly null.
    keep = set() if sample_space is None else set(sample_space)

    outcomes = []
    pmf = []
    # mode='atoms' presumably walks every outcome of d's sample space,
    # including zero-probability ones -- confirm against Distribution.zipped.
    for outcome, prob in d.zipped(mode='atoms'):
        if not d.ops.is_null_exact(prob) or outcome in keep:
            outcomes.append(outcome)
            pmf.append(prob)

    # Use a distinct name instead of rebinding the `sample_space` argument,
    # which holds outcomes to keep rather than a sample-space object.
    if d.is_joint():
        new_space = SampleSpace(outcomes)
    else:
        new_space = ScalarSampleSpace(outcomes)

    pd = d.__class__(outcomes,
                     pmf,
                     sample_space=new_space,
                     base=d.get_base())
    return pd
예제 #7
0
 def setUp(self):
     """Build the per-test SampleSpace fixture over four string outcomes."""
     outcomes = ['00', '01', '10', '12']
     self.samplespace = outcomes
     self.ss = SampleSpace(outcomes, dit.helpers.get_product_func(str))
예제 #8
0
 def test_samplespace_auto(self):
     """A SampleSpace built without a product matches the fixture's outcomes."""
     auto = SampleSpace(['00', '01', '10', '12'])
     assert list(auto) == list(self.ss)
예제 #9
0
 @classmethod
 def setup_class(cls):
     """Build the class-wide SampleSpace fixture over four string outcomes.

     setup_class receives the class object, not an instance, so it is
     declared as a classmethod. The fixture is shared by all tests in the
     class and must not be mutated by them.
     """
     product = dit.helpers.get_product_func(str)
     cls.samplespace = ['00', '01', '10', '12']
     cls.ss = SampleSpace(cls.samplespace, product)