Example #1
import numpy as np

from tmtoolkit.topicmod import evaluate, tm_gensim

# EVALUATION_TEST_DTM, EVALUATION_TEST_TOKENS and EVALUATION_TEST_VOCAB are
# test fixtures defined elsewhere in the test module.


def test_evaluation_gensim_all_metrics():
    passed_params = {'num_topics', 'update_every', 'passes', 'iterations'}
    varying_params = [dict(num_topics=k) for k in range(2, 5)]
    const_params = dict(update_every=0, passes=1, iterations=1)

    eval_res = tm_gensim.evaluate_topic_models(EVALUATION_TEST_DTM, varying_params, const_params,
                                               metric=tm_gensim.AVAILABLE_METRICS,
                                               coherence_gensim_texts=EVALUATION_TEST_TOKENS,
                                               coherence_gensim_kwargs={
                                                   'dictionary': evaluate.FakedGensimDict.from_vocab(EVALUATION_TEST_VOCAB)
                                               },
                                               return_models=True)

    assert len(eval_res) == len(varying_params)

    for param_set, metric_results in eval_res:
        assert set(param_set.keys()) == passed_params
        assert set(metric_results.keys()) == set(tm_gensim.AVAILABLE_METRICS + ('model',))

        assert metric_results['perplexity'] > 0
        assert 0 <= metric_results['cao_juan_2009'] <= 1
        assert metric_results['coherence_mimno_2011'] < 0
        assert np.isclose(metric_results['coherence_gensim_u_mass'], metric_results['coherence_mimno_2011'])
        assert 0 <= metric_results['coherence_gensim_c_v'] <= 1
        assert metric_results['coherence_gensim_c_uci'] < 0
        assert metric_results['coherence_gensim_c_npmi'] < 0
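For reference, a minimal sketch of what such fixtures could look like. The values below are hypothetical (the real EVALUATION_TEST_* fixtures live elsewhere in the test module, and a realistic corpus would be larger); they only illustrate how the tokens, vocabulary and document-term matrix relate to each other.

# Hypothetical fixture sketch: a tiny tokenized corpus, its vocabulary, and
# the matching document-term matrix (rows = documents, columns = vocab order).
EVALUATION_TEST_TOKENS = [
    ['dog', 'cat', 'cat'],
    ['fish', 'dog', 'bird'],
    ['bird', 'bird', 'cat'],
]
EVALUATION_TEST_VOCAB = np.array(['bird', 'cat', 'dog', 'fish'])
EVALUATION_TEST_DTM = np.array([
    [0, 2, 1, 0],
    [1, 0, 1, 1],
    [2, 1, 0, 0],
])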
Example #2
    # Requires module-level imports: warnings, matplotlib.pyplot as plt,
    # tm_gensim, results_by_parameter, plot_eval_results, and the project's
    # own save_folder_file helper.
    def toolkit_cv_plot(self, varying_params,
                        constant_params,
                        save_plot=True,
                        save_dir='results/model_validation',
                        filename='',
                        ext='.pdf',
                        size=(20, 15),
                        **kwargs):
        '''
        Use tmtoolkit to evaluate topic models over a parameter grid,
        scoring each candidate model with a wider variety of measures.
        '''
        # gensim emits noisy UserWarnings during repeated training runs
        warnings.filterwarnings("ignore", category=UserWarning)

        print('evaluating {} topic models'.format(len(varying_params)))
        eval_results = tm_gensim.evaluate_topic_models((self.gensim_dict,
                                                        self.bow),
                                                       varying_params,
                                                       constant_params,
                                                       coherence_gensim_texts=self.text,
                                                       **kwargs)

        results_by_n_topics = results_by_parameter(eval_results, 'num_topics')
        plot_eval_results(results_by_n_topics, xaxislabel='num topics',
                          title='Evaluation results', figsize=size)

        if save_plot:
            # fall back to a default file name if none was given
            filename = filename or 'tmtoolkit_CV_'
            full_path = save_folder_file(save_dir, filename, ext=ext,
                                         optional_folder='convergence_plots')
            plt.savefig(full_path)

        return results_by_n_topics
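A hypothetical call site for this method. The enclosing class is not shown above; the sketch assumes an instance `model` whose gensim_dict, bow and text attributes hold the gensim dictionary, the bag-of-words corpus and the tokenized texts.

# Hypothetical usage of toolkit_cv_plot (instance name and parameter grid
# are assumptions, not part of the original class).
varying = [dict(num_topics=k) for k in range(5, 50, 5)]
constant = dict(update_every=0, passes=20, iterations=400)
results = model.toolkit_cv_plot(varying, constant, save_plot=True)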
Example #3
# evaluate topic models with different parameters
import matplotlib.pyplot as plt

from tmtoolkit.topicmod import tm_gensim
from tmtoolkit.topicmod.evaluate import results_by_parameter
from tmtoolkit.topicmod.visualize import plot_eval_results
from tmtoolkit.utils import pickle_data

# alpha is set per model in varying_params below, so it is not fixed here
const_params = dict(
    update_every=0,
    passes=20,
    iterations=400,
    eta='auto',
)
# candidate numbers of topics: a fine grid up to 140, then a coarser one
ks = list(range(10, 140, 10)) + list(range(140, 200, 20))
varying_params = [dict(num_topics=k, alpha=1.0 / k) for k in ks]

print('evaluating %d topic models' % len(varying_params))
eval_results = tm_gensim.evaluate_topic_models(
    (gnsm_dict, gnsm_corpus),
    varying_params,
    const_params,
    coherence_gensim_texts=model_lists)  # necessary for coherence C_V metric

# save the results as pickle
print('saving results')
pickle_data(eval_results, 'gensim_evaluation_results_entire.pickle')

# plot the results
print('plotting evaluation results')
plt.style.use('ggplot')
results_by_n_topics = results_by_parameter(eval_results, 'num_topics')
plot_eval_results(results_by_n_topics,
                  xaxislabel='num. topics k',
                  title='Evaluation results',
                  figsize=(8, 6))
plt.show()
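As a follow-up, one might pick the number of topics that maximizes a chosen coherence score from the collected results. A sketch, assuming the C_V metric was computed (it is whenever coherence_gensim_texts is passed, as above):

# Sketch: results_by_n_topics is a list of (num_topics, metric_results) pairs,
# so the best k by C_V coherence can be picked with a simple max().
best_k, best_metrics = max(results_by_n_topics,
                           key=lambda kv: kv[1]['coherence_gensim_c_v'])
print('best num_topics by C_V coherence: %d (C_V = %.3f)'
      % (best_k, best_metrics['coherence_gensim_c_v']))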