Code example #1
File: test_sampler.py Project: jzf2101/lda
def test_convergence_simple():
    N, V = 2, 10
    defn = model_definition(N, V)
    data = [
        np.array([5, 6]),
        np.array([0, 1, 2]),
    ]
    view = numpy_dataview(data)
    prng = rng()

    scores = []
    idmap = {}
    for i, (tables, dishes) in enumerate(permutations([2, 3])):
        latent = model.initialize(
            defn, view, prng,
            table_assignments=tables,
            dish_assignments=dishes)
        scores.append(
            latent.score_assignment() +
            latent.score_data(prng))
        idmap[(tables, dishes)] = i
    true_dist = scores_to_probs(scores)

    def kernel(latent):
        # mutates latent in place
        doc_model = model.bind(latent, data=view)
        kernels.assign2(doc_model, prng)
        for did in xrange(latent.nentities()):
            table_model = model.bind(latent, document=did)
            kernels.assign(table_model, prng)

    latent = model.initialize(defn, view, prng)

    skip = 10
    def sample_fn():
        for _ in xrange(skip):
            kernel(latent)
        table_assignments = latent.table_assignments()
        canon_table_assigments = tuple(
            map(tuple, map(permutation_canonical, table_assignments)))

        dish_maps = latent.dish_assignments()
        dish_assignments = []
        for dm, (ta, ca) in zip(dish_maps, zip(table_assignments, canon_table_assigments)):
            dish_assignment = []
            for t, c in zip(ta, ca):
                if c == len(dish_assignment):
                    dish_assignment.append(dm[t])
            dish_assignments.append(dish_assignment)

        canon_dish_assigments = tuple(
            map(tuple, map(permutation_canonical, dish_assignments)))

        return idmap[(canon_table_assigments, canon_dish_assigments)]

    assert_discrete_dist_approx(
        sample_fn, true_dist,
        ntries=100, nsamples=10000, kl_places=2)
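The convergence test relies on a permutation_canonical helper to collapse equivalent labelings before looking up entries in idmap. The project's own implementation is not part of this listing; as a rough sketch (an assumption based on the name and on how sample_fn uses it), it relabels an assignment vector so that labels are numbered in order of first appearance:

def permutation_canonical(assignments):
    # Relabel so labels appear as 0, 1, 2, ... in order of first occurrence;
    # e.g. [3, 1, 3, 2] and [0, 5, 0, 9] both canonicalize to [0, 1, 0, 2].
    relabel = {}
    return [relabel.setdefault(a, len(relabel)) for a in assignments]

assert permutation_canonical([3, 1, 3, 2]) == [0, 1, 0, 2]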
Code example #2
File: test_state.py Project: zhongyunuestc/lda-2
def test_cant_serialize():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    prng = rng()
    s = initialize(defn, data, prng)
    s.serialize()
Code example #3
File: test_state.py Project: datamicroscopes/lda
def test_cant_serialize():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    prng = rng()
    s = initialize(defn, data, prng)
    s.serialize()
Code example #4
File: test_state.py Project: zhongyunuestc/lda-2
def test_alpha_numeric():
    docs = [list('abcd'), list('cdef')]
    defn = model_definition(len(docs), v=6)
    prng = rng()
    s = initialize(defn, docs, prng)
    assert_equals(s.nentities(), len(docs))
    assert_equals(s.nwords(), 6)
Code example #5
File: test_state.py Project: datamicroscopes/lda
def test_alpha_numeric():
    docs = [list('abcd'), list('cdef')]
    defn = model_definition(len(docs), v=6)
    prng = rng()
    s = initialize(defn, docs, prng)
    assert_equals(s.nentities(), len(docs))
    assert_equals(s.nwords(), 6)
Code example #6
def test_simple():
    N, V = 10, 100
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = numpy_dataview(data)
    R = rng()
    s = initialize(defn, view, R)
    assert_equals(s.nentities(), len(data))
Code example #7
File: test_state.py Project: jzf2101/lda
def test_simple():
    N, V = 10, 100
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = numpy_dataview(data)
    R = rng()
    s = initialize(defn, view, R)
    assert_equals(s.nentities(), len(data))
Code example #8
File: test_state.py Project: datamicroscopes/lda
def test_multi_dish_initialization():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    s = initialize(defn, view, prng, initial_dishes=V)
    assert_true(s.ntopics() > 1)
Code example #9
File: test_state.py Project: zhongyunuestc/lda-2
def test_multi_dish_initialization():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    s = initialize(defn, view, prng, initial_dishes=V)
    assert_true(s.ntopics() > 1)
Code example #10
File: test_state.py Project: datamicroscopes/lda
def test_simple():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    s = initialize(defn, view, prng)
    assert_equals(s.nentities(), len(data))
Code example #11
File: test_state.py Project: zhongyunuestc/lda-2
def test_simple():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    s = initialize(defn, view, prng)
    assert_equals(s.nentities(), len(data))
Code example #12
File: test_state.py Project: datamicroscopes/lda
def test_single_dish_initialization():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    s = initialize(defn, view, prng, initial_dishes=1)
    assert_equals(s.ntopics(), 0) # Only dummy topic
Code example #13
File: test_state.py Project: zhongyunuestc/lda-2
def test_single_dish_initialization():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    s = initialize(defn, view, prng, initial_dishes=1)
    assert_equals(s.ntopics(), 0)  # Only dummy topic
Code example #14
File: test_runner.py Project: zhongyunuestc/lda-2
def test_runner_specify_basic_kernel():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    latent = model.initialize(defn, view, prng)
    r = runner.runner(defn, view, latent, ["crf"])
    r.run(prng, 1)
Code example #15
File: test_state.py Project: datamicroscopes/lda
def test_explicit_inception():
    """Initialize a new state using assignments from old

    Helps ensure that our assignment validation code is correct
    """
    N, V = 3, 7
    defn = model_definition(N, V)
    data = [[0, 1, 2, 3], [0, 1, 4], [0, 1, 5, 6]]

    table_assignments = [[1, 2, 1, 2], [1, 1, 1], [3, 3, 3, 1]]
    dish_assignments = [[0, 1, 2], [0, 3], [0, 1, 2, 1]]

    s = initialize(defn, data,
                   table_assignments=table_assignments,
                   dish_assignments=dish_assignments)
    s2 = initialize(defn, data,
                    table_assignments=s.table_assignments(),
                    dish_assignments=s.dish_assignments())
Code example #16
File: test_runner.py Project: zhongyunuestc/lda-2
def test_runner_simple():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    latent = model.initialize(defn, view, prng)
    r = runner.runner(defn, view, latent)
    r.run(prng, 1)
Code example #17
File: test_reuters.py Project: datamicroscopes/lda
    def test_lda_zero_iter(self):
        # compare to model with 0 iterations
        prng2 = rng(seed=54321)
        latent2 = model.initialize(self.defn, self.docs, prng2)
        assert latent2 is not None
        r2 = runner.runner(self.defn, self.docs, latent2)
        assert r2 is not None
        doc_topic2 = latent2.topic_distribution_by_document()
        assert doc_topic2 is not None
        assert latent2.perplexity() > self.latent.perplexity()
Code example #18
File: test_runner.py Project: jzf2101/lda
def test_runner_simple():
    N, V = 10, 100
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = numpy_dataview(data)
    prng = rng()
    latent = model.initialize(defn, view, prng)
    kc = runner.default_kernel_config(defn)
    r = runner.runner(defn, view, latent, kc)
    r.run(prng, 1)
Code example #19
File: test_state.py Project: mrG7/lda
def test_serialize_simple():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    s = initialize(defn, view, prng)
    m = s.serialize()
    s2 = deserialize(defn, m)
    assert s2.__class__ == s.__class__
Code example #20
File: test_reuters.py Project: zhongyunuestc/lda-2
    def test_lda_zero_iter(self):
        # compare to model with 0 iterations
        prng2 = rng(seed=54321)
        latent2 = model.initialize(self.defn, self.docs, prng2)
        assert latent2 is not None
        r2 = runner.runner(self.defn, self.docs, latent2)
        assert r2 is not None
        doc_topic2 = latent2.topic_distribution_by_document()
        assert doc_topic2 is not None
        assert latent2.perplexity() > self.latent.perplexity()
Code example #21
File: test_state.py Project: zhongyunuestc/lda-2
def test_explicit_inception():
    """Initialize a new state using assignments from old

    Helps ensure that our assignment validation code is correct
    """
    N, V = 3, 7
    defn = model_definition(N, V)
    data = [[0, 1, 2, 3], [0, 1, 4], [0, 1, 5, 6]]

    table_assignments = [[1, 2, 1, 2], [1, 1, 1], [3, 3, 3, 1]]
    dish_assignments = [[0, 1, 2], [0, 3], [0, 1, 2, 1]]

    s = initialize(defn,
                   data,
                   table_assignments=table_assignments,
                   dish_assignments=dish_assignments)
    s2 = initialize(defn,
                    data,
                    table_assignments=s.table_assignments(),
                    dish_assignments=s.dish_assignments())
Code example #22
File: test_reuters.py Project: datamicroscopes/lda
    def setup_class(cls):
        cls._load_docs()
        cls.niters = 100 if os.environ.get('TRAVIS') else 2

        cls.defn = model_definition(cls.N, cls.V)
        cls.seed = 12345
        cls.prng = rng(seed=cls.seed)
        cls.latent = model.initialize(cls.defn, cls.docs, cls.prng)
        cls.r = runner.runner(cls.defn, cls.docs, cls.latent)
        cls.original_perplexity = cls.latent.perplexity()
        cls.r.run(cls.prng, cls.niters)
        cls.doc_topic = cls.latent.topic_distribution_by_document()
Code example #23
File: test_runner.py Project: zhongyunuestc/lda-2
def test_runner_specify_hp_kernels():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    latent = model.initialize(defn, view, prng)
    kernels = ['crf'] + \
        runner.second_dp_hp_kernel_config(defn) + \
        runner.base_dp_hp_kernel_config(defn)
    r = runner.runner(defn, view, latent, kernels)
    r.run(prng, 1)
Code example #24
File: test_reuters.py Project: zhongyunuestc/lda-2
    def setup_class(cls):
        cls._load_docs()
        cls.niters = 100 if os.environ.get('TRAVIS') else 2

        cls.defn = model_definition(cls.N, cls.V)
        cls.seed = 12345
        cls.prng = rng(seed=cls.seed)
        cls.latent = model.initialize(cls.defn, cls.docs, cls.prng)
        cls.r = runner.runner(cls.defn, cls.docs, cls.latent)
        cls.original_perplexity = cls.latent.perplexity()
        cls.r.run(cls.prng, cls.niters)
        cls.doc_topic = cls.latent.topic_distribution_by_document()
Code example #25
File: test_state.py Project: zhongyunuestc/lda-2
def test_serialize_simple():
    docs = [list('abcd'), list('cdef')]
    defn = model_definition(len(docs), v=6)
    prng = rng()
    s = initialize(defn, docs, prng)
    m = s.serialize()
    s2 = deserialize(defn, m)
    assert s2.__class__ == s.__class__
    assert all(word in "abcdef" for wd in s2.word_distribution_by_topic()
               for word in wd.keys())
    assert all(
        isinstance(word, str) for wd in s2.word_distribution_by_topic()
        for word in wd.keys())
Code example #26
File: test_state.py Project: datamicroscopes/lda
def test_pyldavis_data():
    docs = [list('abcd'), list('cdef')]
    defn = model_definition(len(docs), v=6)
    prng = rng()
    s = initialize(defn, docs, prng)
    data = s.pyldavis_data()
    index_of_a = data['vocab'].index('a')
    index_of_c = data['vocab'].index('c')
    assert_equals(data['term_frequency'][index_of_a], 1)
    assert_equals(data['term_frequency'][index_of_c], 2)
    for dist in data['topic_term_dists']:
        assert_almost_equals(sum(dist), 1)
    for dist in data['doc_topic_dists']:
        assert_almost_equals(sum(dist), 1)
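The two term_frequency assertions follow directly from the toy corpus: across ['abcd', 'cdef'] the token 'a' occurs once and 'c' occurs twice. A standalone check of those counts, independent of the state object:

from collections import Counter

docs = [list('abcd'), list('cdef')]
counts = Counter(word for doc in docs for word in doc)
assert counts['a'] == 1  # 'a' appears only in the first document
assert counts['c'] == 2  # 'c' appears in both documents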
Code example #27
File: test_state.py Project: datamicroscopes/lda
def test_serialize_simple():
    docs = [list('abcd'), list('cdef')]
    defn = model_definition(len(docs), v=6)
    prng = rng()
    s = initialize(defn, docs, prng)
    m = s.serialize()
    s2 = deserialize(defn, m)
    assert s2.__class__ == s.__class__
    assert all(word in "abcdef"
               for wd in s2.word_distribution_by_topic()
               for word in wd.keys())
    assert all(isinstance(word, str)
               for wd in s2.word_distribution_by_topic()
               for word in wd.keys())
Code example #28
File: test_reuters.py Project: zhongyunuestc/lda-2
    def test_lda_random_seed(self):
        # ensure that randomness is contained in rng
        # by running model twice with same seed
        niters = 10

        # model 1
        prng1 = rng(seed=54321)
        latent1 = model.initialize(self.defn, self.docs, prng1)
        runner1 = runner.runner(self.defn, self.docs, latent1)
        runner1.run(prng1, niters)

        # model2
        prng2 = rng(seed=54321)
        latent2 = model.initialize(self.defn, self.docs, prng2)
        runner2 = runner.runner(self.defn, self.docs, latent2)
        runner2.run(prng2, niters)

        assert_list_equal(latent1.topic_distribution_by_document(),
                          latent2.topic_distribution_by_document())

        for d1, d2 in zip(latent1.word_distribution_by_topic(),
                          latent2.word_distribution_by_topic()):
            assert_dict_equal(d1, d2)
Code example #29
File: test_state.py Project: datamicroscopes/lda
def test_serialize_pickle():
    docs = [list('abcd'), list('cdef')]
    defn = model_definition(len(docs), v=6)
    prng = rng()
    s = initialize(defn, docs, prng)
    # Pickle
    bstr = pickle.dumps(s)
    s2 = pickle.loads(bstr)
    assert s2.__class__ == s.__class__

    # cPickle
    bstr = cPickle.dumps(s)
    s2 = cPickle.loads(bstr)
    assert s2.__class__ == s.__class__
Code example #30
File: test_reuters.py Project: datamicroscopes/lda
    def test_lda_random_seed(self):
        # ensure that randomness is contained in rng
        # by running model twice with same seed
        niters = 10

        # model 1
        prng1 = rng(seed=54321)
        latent1 = model.initialize(self.defn, self.docs, prng1)
        runner1 = runner.runner(self.defn, self.docs, latent1)
        runner1.run(prng1, niters)

        # model2
        prng2 = rng(seed=54321)
        latent2 = model.initialize(self.defn, self.docs, prng2)
        runner2 = runner.runner(self.defn, self.docs, latent2)
        runner2.run(prng2, niters)

        assert_list_equal(latent1.topic_distribution_by_document(),
                          latent2.topic_distribution_by_document())

        for d1, d2 in zip(latent1.word_distribution_by_topic(),
                          latent2.word_distribution_by_topic()):
            assert_dict_equal(d1, d2)
Code example #31
File: test_state.py Project: zhongyunuestc/lda-2
def test_relevance():
    docs = [list('abcd'), list('cdef')]
    defn = model_definition(len(docs), v=6)
    prng = rng()
    s = initialize(defn, docs, prng)
    s.term_relevance_by_topic(weight=0)
    s.term_relevance_by_topic(weight=1)
    rel = s.term_relevance_by_topic()
    assert isinstance(rel, list)
    assert isinstance(rel[0], list)
    assert len(rel) == s.ntopics()
    assert len(rel[0]) == s.nwords()
    assert rel[0] == sorted(rel[0], key=lambda (_, r): r, reverse=True)
    assert rel[-1][0] < rel[-1][-1]
Code example #32
File: test_state.py Project: zhongyunuestc/lda-2
def test_pyldavis_data():
    docs = [list('abcd'), list('cdef')]
    defn = model_definition(len(docs), v=6)
    prng = rng()
    s = initialize(defn, docs, prng)
    data = s.pyldavis_data()
    index_of_a = data['vocab'].index('a')
    index_of_c = data['vocab'].index('c')
    assert_equals(data['term_frequency'][index_of_a], 1)
    assert_equals(data['term_frequency'][index_of_c], 2)
    for dist in data['topic_term_dists']:
        assert_almost_equals(sum(dist), 1)
    for dist in data['doc_topic_dists']:
        assert_almost_equals(sum(dist), 1)
Code example #33
File: test_state.py Project: zhongyunuestc/lda-2
def test_serialize_pickle():
    docs = [list('abcd'), list('cdef')]
    defn = model_definition(len(docs), v=6)
    prng = rng()
    s = initialize(defn, docs, prng)
    # Pickle
    bstr = pickle.dumps(s)
    s2 = pickle.loads(bstr)
    assert s2.__class__ == s.__class__

    # cPickle
    bstr = cPickle.dumps(s)
    s2 = cPickle.loads(bstr)
    assert s2.__class__ == s.__class__
Code example #34
File: test_runner.py Project: zhongyunuestc/lda-2
def test_runner_second_dp_valid():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    prng = rng()
    latent = model.initialize(defn, data, prng)
    old_beta = latent.beta
    old_gamma = latent.gamma
    kernels = ['crf'] + \
        runner.second_dp_hp_kernel_config(defn)
    r = runner.runner(defn, data, latent, kernels)
    r.run(prng, 10)
    assert_almost_equals(latent.beta, old_beta)
    assert_almost_equals(latent.gamma, old_gamma)
    assert latent.alpha > 0
Code example #35
File: test_state.py Project: datamicroscopes/lda
def test_relevance():
    docs = [list('abcd'), list('cdef')]
    defn = model_definition(len(docs), v=6)
    prng = rng()
    s = initialize(defn, docs, prng)
    s.term_relevance_by_topic(weight=0)
    s.term_relevance_by_topic(weight=1)
    rel = s.term_relevance_by_topic()
    assert isinstance(rel, list)
    assert isinstance(rel[0], list)
    assert len(rel) == s.ntopics()
    assert len(rel[0]) == s.nwords()
    assert rel[0] == sorted(rel[0],
                            key=lambda (_, r): r,
                            reverse=True)
    assert rel[-1][0] < rel[-1][-1]
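Note that key=lambda (_, r): r in the relevance tests uses tuple parameter unpacking, which is Python 2-only syntax (removed in Python 3 by PEP 3113); together with xrange and cPickle elsewhere in this listing, it shows these tests target Python 2. Under Python 3 the same sort key would be written without unpacking, for example:

# Python 3 equivalent of the Python 2 sort key `lambda (_, r): r`
pairs = [('a', 0.3), ('b', 0.9), ('c', 0.1)]
ranked = sorted(pairs, key=lambda pair: pair[1], reverse=True)
assert ranked == [('b', 0.9), ('a', 0.3), ('c', 0.1)]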
Code example #36
File: test_state.py Project: mrG7/lda
def test_serialize_pickle():
    N, V = 10, 20
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = data
    prng = rng()
    s = initialize(defn, view, prng)
    # Pickle
    bstr = pickle.dumps(s)
    s2 = pickle.loads(bstr)
    assert s2.__class__ == s.__class__

    # cPickle
    bstr = cPickle.dumps(s)
    s2 = cPickle.loads(bstr)
    assert s2.__class__ == s.__class__
Code example #37
File: test_state.py Project: jzf2101/lda
def test_explicit():
    N, V = 5, 100
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = numpy_dataview(data)
    R = rng()

    table_assignments = [
        np.random.randint(low=0, high=10, size=len(d)) for d in data]

    dish_assignments = [
        np.random.randint(low=0, high=len(t)) for t in table_assignments]

    s = initialize(defn, view, R,
        table_assignments=table_assignments,
        dish_assignments=dish_assignments)
    assert_equals(s.nentities(), len(data))
Code example #38
def test_explicit():
    N, V = 5, 100
    defn = model_definition(N, V)
    data = toy_dataset(defn)
    view = numpy_dataview(data)
    R = rng()

    table_assignments = [
        np.random.randint(low=0, high=10, size=len(d)) for d in data
    ]

    dish_assignments = [
        np.random.randint(low=0, high=len(t)) for t in table_assignments
    ]

    s = initialize(defn,
                   view,
                   R,
                   table_assignments=table_assignments,
                   dish_assignments=dish_assignments)
    assert_equals(s.nentities(), len(data))
Code example #39
File: test_state.py Project: datamicroscopes/lda
def test_explicit():
    """Test that we can explicitly initialize state by specifying
    table and dish assignments
    """
    N, V = 3, 7
    defn = model_definition(N, V)
    data = [[0, 1, 2, 3], [0, 1, 4], [0, 1, 5, 6]]

    table_assignments = [[1, 2, 1, 2], [1, 1, 1], [3, 3, 3, 1]]
    dish_assignments = [[0, 1, 2], [0, 3], [0, 1, 2, 1]]

    s = initialize(defn, data,
                   table_assignments=table_assignments,
                   dish_assignments=dish_assignments)
    assert_equals(s.nentities(), len(data))
    assert len(s.dish_assignments()) == len(dish_assignments)
    assert len(s.table_assignments()) == len(table_assignments)
    for da1, da2 in zip(s.dish_assignments(), dish_assignments):
        assert da1 == da2
    for ta1, ta2 in zip(s.table_assignments(), table_assignments):
        assert ta1 == ta2
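In the explicit-initialization fixtures above, the assignment lists appear to follow the Chinese-restaurant-franchise layout: table_assignments[d] holds one table index per word of document d, while dish_assignments[d] holds one dish index per table of that document, including a slot for table 0, so its length is max(table_assignments[d]) + 1. That reading is inferred from the data, not from documented API behaviour; a standalone consistency check of the fixture under this assumed layout:

data = [[0, 1, 2, 3], [0, 1, 4], [0, 1, 5, 6]]
table_assignments = [[1, 2, 1, 2], [1, 1, 1], [3, 3, 3, 1]]
dish_assignments = [[0, 1, 2], [0, 3], [0, 1, 2, 1]]

for doc, tables, dishes in zip(data, table_assignments, dish_assignments):
    assert len(tables) == len(doc)         # one table per word
    assert len(dishes) == max(tables) + 1  # one dish per table, incl. table 0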
Code example #40
File: test_state.py Project: zhongyunuestc/lda-2
def test_explicit():
    """Test that we can explicitly initialize state by specifying
    table and dish assignments
    """
    N, V = 3, 7
    defn = model_definition(N, V)
    data = [[0, 1, 2, 3], [0, 1, 4], [0, 1, 5, 6]]

    table_assignments = [[1, 2, 1, 2], [1, 1, 1], [3, 3, 3, 1]]
    dish_assignments = [[0, 1, 2], [0, 3], [0, 1, 2, 1]]

    s = initialize(defn,
                   data,
                   table_assignments=table_assignments,
                   dish_assignments=dish_assignments)
    assert_equals(s.nentities(), len(data))
    assert len(s.dish_assignments()) == len(dish_assignments)
    assert len(s.table_assignments()) == len(table_assignments)
    for da1, da2 in zip(s.dish_assignments(), dish_assignments):
        assert da1 == da2
    for ta1, ta2 in zip(s.table_assignments(), table_assignments):
        assert ta1 == ta2
Code example #41
def test_convergence_simple():
    N, V = 2, 10
    defn = model_definition(N, V)
    data = [
        np.array([5, 6]),
        np.array([0, 1, 2]),
    ]
    view = numpy_dataview(data)
    prng = rng()

    scores = []
    idmap = {}
    for i, (tables, dishes) in enumerate(permutations([2, 3])):
        latent = model.initialize(defn,
                                  view,
                                  prng,
                                  table_assignments=tables,
                                  dish_assignments=dishes)
        scores.append(latent.score_assignment() + latent.score_data(prng))
        idmap[(tables, dishes)] = i
    true_dist = scores_to_probs(scores)

    def kernel(latent):
        # mutates latent in place
        doc_model = model.bind(latent, data=view)
        kernels.assign2(doc_model, prng)
        for did in xrange(latent.nentities()):
            table_model = model.bind(latent, document=did)
            kernels.assign(table_model, prng)

    latent = model.initialize(defn, view, prng)

    skip = 10

    def sample_fn():
        for _ in xrange(skip):
            kernel(latent)
        table_assignments = latent.table_assignments()
        canon_table_assigments = tuple(
            map(tuple, map(permutation_canonical, table_assignments)))

        dish_maps = latent.dish_assignments()
        dish_assignments = []
        for dm, (ta, ca) in zip(dish_maps,
                                zip(table_assignments,
                                    canon_table_assigments)):
            dish_assignment = []
            for t, c in zip(ta, ca):
                if c == len(dish_assignment):
                    dish_assignment.append(dm[t])
            dish_assignments.append(dish_assignment)

        canon_dish_assigments = tuple(
            map(tuple, map(permutation_canonical, dish_assignments)))

        return idmap[(canon_table_assigments, canon_dish_assigments)]

    assert_discrete_dist_approx(sample_fn,
                                true_dist,
                                ntries=100,
                                nsamples=10000,
                                kl_places=2)
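scores_to_probs is used in the convergence tests to turn the joint log score of each enumerated state into the reference distribution true_dist. Its implementation is not part of this listing; the behaviour the tests rely on is a standard exp-normalization of log scores, which can be sketched as follows (a stand-in, not the library's actual code):

import numpy as np

def scores_to_probs_sketch(scores):
    # Exp-normalize log scores into a probability vector,
    # subtracting the max first for numerical stability.
    scores = np.asarray(scores, dtype=float)
    weights = np.exp(scores - scores.max())
    return weights / weights.sum()

probs = scores_to_probs_sketch([-3.0, -1.0, -2.0])
assert abs(probs.sum() - 1.0) < 1e-12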