def test_compute_models_parallel_lda_multi_vs_singleproc():
    """Compare LDA models from a default run with a ``n_max_processes=1`` run.

    The same varying and constant parameters are passed twice to
    ``tm_lda.compute_models_parallel``: once without ``n_max_processes`` and
    once with ``n_max_processes=1``. With ``random_state`` fixed in both runs,
    the test expects models with equal parameter sets to have numerically
    close ``doc_topic_`` and ``topic_word_`` arrays.
    """
    passed_params = {'n_topics', 'n_iter', 'random_state'}
    varying_params = [dict(n_topics=k) for k in range(2, 5)]
    const_params = dict(n_iter=3, random_state=1)

    # first run: number of processes is left to the function's default
    models = tm_lda.compute_models_parallel(EVALUATION_TEST_DTM, varying_params, const_params)
    assert len(models) == len(varying_params)

    for param_set, model in models:
        # each result carries the merged (varying + constant) parameter names
        assert set(param_set.keys()) == passed_params
        assert isinstance(model, lda.LDA)
        assert isinstance(model.doc_topic_, np.ndarray)
        assert isinstance(model.topic_word_, np.ndarray)

    # second run: identical parameters, restricted to a single process
    models_singleproc = tm_lda.compute_models_parallel(EVALUATION_TEST_DTM, varying_params, const_params,
                                                       n_max_processes=1)

    assert len(models_singleproc) == len(models)
    for param_set2, model2 in models_singleproc:
        # Pair results by parameter set instead of by position, so the test
        # does not depend on the order in which results are returned.
        for candidate_params, candidate_model in models:
            if candidate_params == param_set2:
                model1 = candidate_model
                break
        else:
            # Raise explicitly (instead of `assert False`) so the failure names
            # the missing parameter set and cannot be optimized away, which
            # would leave `model1` stale from the previous iteration.
            raise AssertionError(
                'no model with parameter set {!r} found in the first run'.format(param_set2)
            )

        assert np.allclose(model1.doc_topic_, model2.doc_topic_)
        assert np.allclose(model1.topic_word_, model2.topic_word_)
def test_compute_models_parallel_lda_multiple_docs():
    """Exercise ``tm_lda.compute_models_parallel`` with different input shapes.

    Four scenarios are checked in sequence:

    1. a single unnamed DTM with constant parameters only (result is a list),
    2. a single DTM wrapped in a dict, with varying parameters,
    3. ``EVALUATION_TEST_DTM_MULTI`` with constant parameters only,
    4. ``EVALUATION_TEST_DTM_MULTI`` with varying parameters.

    For the dict inputs the result must be a dict with the same keys.
    """

    def check_fitted(params, fitted, expected_keys):
        # shared per-result checks: parameter names and fitted LDA attributes
        assert set(params.keys()) == expected_keys
        assert isinstance(fitted, lda.LDA)
        assert isinstance(fitted.doc_topic_, np.ndarray)
        assert isinstance(fitted.topic_word_, np.ndarray)

    # --- scenario 1: single unnamed DTM, no varying parameters ---
    fixed = dict(n_topics=3, n_iter=3, random_state=1)
    single_result = tm_lda.compute_models_parallel(EVALUATION_TEST_DTM, constant_parameters=fixed)
    assert len(single_result) == 1
    assert type(single_result) is list
    assert len(single_result[0]) == 2
    ref_params, ref_model = single_result[0]
    assert ref_params == fixed
    assert isinstance(ref_model, lda.LDA)
    assert isinstance(ref_model.doc_topic_, np.ndarray)
    assert isinstance(ref_model.topic_word_, np.ndarray)

    # --- scenario 2: single *named* DTM, some varying parameters ---
    expected_keys = {'n_topics', 'n_iter', 'random_state'}
    fixed = dict(n_iter=3, random_state=1)
    varied = [dict(n_topics=k) for k in range(2, 5)]
    named_docs = {'test1': EVALUATION_TEST_DTM}
    named_result = tm_lda.compute_models_parallel(named_docs, varied, constant_parameters=fixed)
    assert len(named_result) == len(named_docs)
    assert isinstance(named_result, dict)
    assert set(named_result.keys()) == {'test1'}

    # one of the varying parameter sets equals the scenario-1 parameters; its
    # model must reproduce the scenario-1 model
    found_reference = False
    for label, results in named_result.items():
        assert label == 'test1'
        assert len(results) == len(varied)
        for params, fitted in results:
            check_fitted(params, fitted, expected_keys)

            if params == ref_params:
                assert np.allclose(fitted.doc_topic_, ref_model.doc_topic_)
                assert np.allclose(fitted.topic_word_, ref_model.topic_word_)
                found_reference = True

    assert found_reference

    # --- scenario 3: several DTMs, no varying parameters ---
    fixed = dict(n_topics=3, n_iter=3, random_state=1)
    multi_result = tm_lda.compute_models_parallel(EVALUATION_TEST_DTM_MULTI, constant_parameters=fixed)
    assert len(multi_result) == len(EVALUATION_TEST_DTM_MULTI)
    assert isinstance(multi_result, dict)
    assert set(multi_result.keys()) == set(EVALUATION_TEST_DTM_MULTI.keys())

    for results in multi_result.values():
        assert len(results) == 1
        for params, fitted in results:
            check_fitted(params, fitted, set(fixed.keys()))

    # --- scenario 4: several DTMs, some varying parameters ---
    expected_keys = {'n_topics', 'n_iter', 'random_state'}
    fixed = dict(n_iter=3, random_state=1)
    varied = [dict(n_topics=k) for k in range(2, 5)]
    multi_result = tm_lda.compute_models_parallel(EVALUATION_TEST_DTM_MULTI, varied, constant_parameters=fixed)
    assert len(multi_result) == len(EVALUATION_TEST_DTM_MULTI)
    assert isinstance(multi_result, dict)
    assert set(multi_result.keys()) == set(EVALUATION_TEST_DTM_MULTI.keys())

    for results in multi_result.values():
        assert len(results) == len(varied)
        for params, fitted in results:
            check_fitted(params, fitted, expected_keys)