示例#1
0
def test__parse_feature_reduction():
    """Check ``Parameters._parse_feature_reduction`` on known inputs.

    The string 'None' must yield ``None``, 'LDA' and 'MDS' must be handed
    back as the very same objects, and 'invalid_value' must raise
    ``AssertionError``.
    """
    parse = Parameters._parse_feature_reduction
    assert parse('None') is None
    for accepted in ('LDA', 'MDS'):
        assert parse(accepted) is accepted
    with pytest.raises(AssertionError):
        parse('invalid_value')
def _run_in_data_dir(data_dir, entry_point, *args):
    """Call ``entry_point(*args)`` with ``data_dir`` as the working directory.

    The directory is created when it is missing. The previous working
    directory is restored in a ``finally`` clause, so an exception raised by
    ``entry_point`` (``SystemExit`` included) no longer leaves the process
    inside ``data_dir``.
    """
    os.makedirs(data_dir, exist_ok=True)
    old_dir = os.getcwd()
    os.chdir(data_dir)
    try:
        entry_point(*args)
    finally:
        os.chdir(old_dir)


def main(argv):
    """Dispatch the command line to the trainer or the prediction server.

    Args:
        argv: Full argument vector, program name first.
            ``[prog, "--trainer", <configuration file>]`` runs
            ``trainer.main(parameters)``;
            ``[prog, "--prediction_server", <configuration file>, <port>]``
            runs ``prediction_server.main(parameters, port)``.
            Any other shape prints the usage text and returns.

    Raises:
        ValueError: If the port argument cannot be converted with ``int``.
    """
    verify_python_version()
    n_args = len(argv)
    mode = argv[1] if n_args > 1 else None
    if mode == "--trainer" and n_args == 3:
        parameters = Parameters(argv[2])
        _run_in_data_dir(parameters.data_dir, trainer.main, parameters)
    elif mode == "--prediction_server" and n_args == 4:
        config_file = argv[2]
        # The port is converted before the configuration is loaded, keeping
        # the original evaluation order.
        port = int(argv[3])
        parameters = Parameters(config_file)
        _run_in_data_dir(parameters.data_dir, prediction_server.main,
                         parameters, port)
    else:
        print(
            "Usage: python3 -m text_categorizer --trainer <configuration file>"
        )
        print(
            "       python3 -m text_categorizer --prediction_server <configuration file> <port>"
        )
示例#3
0
def test__parse_class_weights():
    """Check ``Parameters._parse_class_weights`` on known inputs.

    The string 'None' must yield ``None``, 'balanced' must be handed back as
    the very same object, and 'invalid_value' must raise ``AssertionError``.
    """
    assert Parameters._parse_class_weights('None') is None
    values = ['balanced']
    for value in values:
        assert Parameters._parse_class_weights(value) is value
    with pytest.raises(AssertionError):
        # Exercise the class-weights parser itself; the previous version
        # called _parse_feature_reduction here, so rejection of invalid
        # class weights was never actually tested.
        Parameters._parse_class_weights('invalid_value')
示例#4
0
def test__parse_resampling():
    """Check ``Parameters._parse_resampling`` on known inputs.

    The string 'None' must yield ``None``, the two sampler names must be
    handed back as the very same objects, and 'invalid_value' must raise
    ``AssertionError``.
    """
    parse = Parameters._parse_resampling
    assert parse('None') is None
    for accepted in ('RandomOverSample', 'RandomUnderSample'):
        assert parse(accepted) is accepted
    with pytest.raises(AssertionError):
        parse('invalid_value')
示例#5
0
def test__parse_number_of_jobs():
    """Check ``Parameters._parse_number_of_jobs`` on known inputs.

    'None' must map to 1, the strings '1'..'4' to the same integers, the
    strings '-4'..'-1' to ``cpu_count() + 1 + n``, and the integer 0 must
    raise ``AssertionError``.
    """
    parse = Parameters._parse_number_of_jobs
    assert parse('None') == 1
    for jobs in (1, 2, 3, 4):
        assert parse(str(jobs)) == jobs
    for jobs in (-4, -3, -2, -1):
        assert parse(str(jobs)) == cpu_count() + 1 + jobs
    with pytest.raises(AssertionError):
        parse(0)
示例#6
0
def test__parse_vectorizer():
    """Check ``Parameters._parse_vectorizer`` on known inputs.

    The class name of each listed vectorizer must be returned unchanged, and
    'invalid_vectorizer' must raise ``AssertionError``.
    """
    parse = Parameters._parse_vectorizer
    supported = (
        TfidfVectorizer,
        CountVectorizer,
        HashingVectorizer,
        DocumentPoolEmbeddings,
    )
    names = [cls.__name__ for cls in supported]
    for name in names:
        assert parse(name) == name
    with pytest.raises(AssertionError):
        parse('invalid_vectorizer')
def test_generate_report():
    """Run ``functions.generate_report`` twice against the same Excel file.

    The first call happens while ``excel_file1`` does not exist, the second
    once it does. After call number ``i`` (0-based) the returned frame must
    have ``i + 1`` rows and 44 columns and equal ``i + 1`` stacked copies of
    the expected row. The returned frame is also written to a second file so
    that both files can be compared after the same Excel round trip.
    """
    execution_info = pd.DataFrame.from_dict({
        'Start': [functions.get_local_time_str()],
        'End': [functions.get_local_time_str()],
    })
    parameters_dict = Parameters(utils.config_file).__dict__
    predictions_dict = {
        'y_true': ['label1'],
        'classifier_key': [{'label1': 0.0, 'label2': 1.0}],
    }
    # The expected row is built with the scalar 1, while the report itself
    # is generated from the one-element set {1} assigned further down.
    parameters_dict['set_num_accepted_probs'] = 1
    expected_df_row0 = pd.concat([
        execution_info,
        functions.parameters_to_data_frame(parameters_dict),
        functions.predictions_to_data_frame(predictions_dict, 1),
    ], axis=1)
    parameters_dict['set_num_accepted_probs'] = {1}
    excel_file1 = utils.generate_available_filename(ext='.xlsx')
    excel_file2 = utils.generate_available_filename(ext='.xlsx')
    expected_df = pd.DataFrame()
    try:
        for i, file_exists in enumerate([False, True]):
            assert exists(excel_file1) is file_exists
            df = functions.generate_report(execution_info, parameters_dict, predictions_dict, excel_file1)
            df.to_excel(excel_file2, index=False)
            assert df.shape == (i + 1, 44)
            expected_df = pd.concat([expected_df, expected_df_row0])
            # pd.testing is the public home of assert_frame_equal; the
            # pd.util.testing alias is deprecated and absent from recent
            # pandas releases.
            pd.testing.assert_frame_equal(df, expected_df)
            pd.testing.assert_frame_equal(pd.read_excel(excel_file1), pd.read_excel(excel_file2))
    finally:
        utils.remove_and_check(excel_file1)
        utils.remove_and_check(excel_file2)
def test_parameters_to_data_frame():
    """Check the one-row frame built by ``functions.parameters_to_data_frame``.

    The frame built from the attributes of ``Parameters(utils.config_file)``
    must have shape ``(1, 22)`` and its single row, as a dict, must equal the
    expected column-to-value mapping below.
    """
    expected_row = {
        'Excel file': abspath('example_excel_file.xlsx'),
        'Text column': 'Example column',
        'Label column': 'Classification column',
        'n_jobs': cpu_count(),
        'Preprocessed data file': abspath('./data/preprocessed_data.pkl'),
        'Data directory': abspath('./data'),
        'Final training': False,
        'Preprocess data': True,
        'MosesTokenizer language code': 'en',
        'Spell checker language': 'None',
        'NLTK stop words package': 'english',
        'Document adjustment code': abspath('text_categorizer/document_updater.py'),
        'Vectorizer': 'TfidfVectorizer',
        'Feature reduction': 'None',
        'Remove adjectives': False,
        'Synonyms file': 'None',
        'Accepted probabilities': {1, 2, 3},
        'Test size': 0.3,
        'Force subsets regeneration': False,
        'Resampling': 'None',
        'Class weights': 'None',
        'Generate ROC plots': False,
    }
    parameters = Parameters(utils.config_file)
    frame = functions.parameters_to_data_frame(parameters.__dict__)
    assert frame.shape == (1, 22)
    actual_row = frame.iloc[0].to_dict()
    assert actual_row == expected_row
示例#9
0
def test_main():
    """Run ``trainer.main`` through four scenarios inside a scratch directory.

    1. Invalid Excel and preprocessed-data paths: ``SystemExit`` is expected.
    2. Unmodified base parameters: the preprocessed data file,
       ``predictions.json`` and ``report.xlsx`` must exist afterwards.
    3. ``excel_file`` set to the absolute path of "20newsgroups" with
       ``preprocess_data`` disabled: ``SystemExit`` is expected and
       ``20newsgroups.xlsx`` must exist afterwards.
    4. ``final_training`` enabled: neither ``predictions.json`` nor
       ``report.xlsx`` may exist afterwards.

    The original working directory is restored and the scratch directory is
    removed at the end, whatever the outcome.
    """
    old_dir = os.getcwd()
    # presumably a path that does not exist yet; makedirs(exist_ok=False)
    # below raises if it does - verify against utils.
    new_dir = utils.generate_available_filename()
    base_parameters = Parameters(utils.config_file)
    # Keep only the file name so the path is relative and therefore resolved
    # against the current working directory.
    base_parameters.preprocessed_data_file = os.path.basename(base_parameters.preprocessed_data_file)
    try:
        os.makedirs(new_dir, exist_ok=False)
        os.chdir(new_dir)
        # Scenario 1: both input paths are invalid.
        parameters = deepcopy(base_parameters)
        parameters.excel_file = "invalid_excel_file"
        parameters.preprocessed_data_file = "invalid_data_file"
        with pytest.raises(SystemExit):
            trainer.main(parameters)
        # Scenario 2: regular run with the base parameters.
        parameters = deepcopy(base_parameters)
        assert not os.path.exists(parameters.preprocessed_data_file)
        try:
            trainer.main(parameters)
            assert os.path.exists(parameters.preprocessed_data_file)
            assert os.path.exists("predictions.json")
            assert os.path.exists("report.xlsx")
        finally:
            utils.remove_and_check(parameters.preprocessed_data_file)
            utils.remove_and_check("predictions.json")
            utils.remove_and_check("report.xlsx")
        # Scenario 3: reuses the parameters object from scenario 2 (no new
        # deepcopy) and only overrides the two fields below.
        parameters.excel_file = os.path.abspath("20newsgroups")
        parameters.preprocess_data = False
        excel_file_20newsgroups = "20newsgroups.xlsx"
        assert not os.path.exists(excel_file_20newsgroups)
        try:
            trainer.main(parameters)
            # Only reached when trainer.main returned without raising.
            pytest.fail()
        except SystemExit:
            assert os.path.exists(excel_file_20newsgroups)
        finally:
            utils.remove_and_check(excel_file_20newsgroups)
        # Scenario 4: final training.
        parameters = deepcopy(base_parameters)
        parameters.final_training = True
        try:
            trainer.main(parameters)
        finally:
            # These checks sit in ``finally`` and so run even if
            # trainer.main raised.
            assert not os.path.exists("predictions.json")
            assert not os.path.exists("report.xlsx")
            utils.remove_and_check(parameters.preprocessed_data_file)
    finally:
        os.chdir(old_dir)
        rmtree(new_dir)
示例#10
0
def test_main(monkeypatch):
    """Exercise ``prediction_server.main`` and ``_signal_handler``.

    The test first expects ``SystemExit`` from a call made before
    'vectorizer.pkl' has been written. It then replaces
    ``WSGIServer.serve_forever`` with a stub returning ``None``, dumps a
    ``TfidfVectorizer``-based vectorizer to 'vectorizer.pkl', and checks the
    module-level state that ``prediction_server.main`` leaves behind.
    Finally it starts the WSGI server by hand and checks that
    ``_signal_handler`` closes it for every signal in
    ``constants.stop_signals`` but not for ``SIGILL``.
    """
    parameters = Parameters(utils.config_file)
    # NOTE(review): presumably exits because the vectorizer file does not
    # exist yet - confirm against prediction_server.main.
    with pytest.raises(SystemExit):
        prediction_server.main(parameters, 1024)
    with monkeypatch.context() as m:
        # Stub out serve_forever; the stub just returns None.
        m.setattr("gevent.pywsgi.WSGIServer.serve_forever",
                  lambda stop_timeout: None)
        try:
            vectorizer_file = 'vectorizer.pkl'
            dump(
                FeatureExtractor(vectorizer_name='TfidfVectorizer').vectorizer,
                vectorizer_file)
            # State expected before the first successful call.
            assert prediction_server._old_handlers == dict()
            assert prediction_server.logger.disabled is False
            prediction_server.main(parameters, 1025)
            # State expected after the call.
            assert prediction_server.logger.disabled is True
            assert prediction_server._text_field == 'Example column'
            assert prediction_server._class_field == 'Classification column'
            assert prediction_server._preprocessor.mosestokenizer_language_code == 'en'
            assert prediction_server._preprocessor.store_data is False
            assert prediction_server._preprocessor.spell_checker is None
            #assert prediction_server._preprocessor.spell_checker.hunspell.max_threads == cpu_count()
            assert len(prediction_server._feature_extractor.stop_words) > 0
            assert prediction_server._feature_extractor.feature_reduction is None
            assert prediction_server._feature_extractor.document_adjustment_code.__file__ == abspath(
                'text_categorizer/document_updater.py')
            assert prediction_server._feature_extractor.synonyms is None
            assert prediction_server._feature_extractor.vectorizer_file == vectorizer_file
            assert prediction_server._feature_extractor.n_jobs == cpu_count()
            # With the real _reset_signal_handlers in place the mapping is
            # empty again after main returns.
            assert prediction_server._old_handlers == dict()
            # With _reset_signal_handlers replaced by a no-op, one entry is
            # expected to remain; the test then clears the mapping itself.
            m.setattr(
                "text_categorizer.prediction_server._reset_signal_handlers",
                lambda: None)
            prediction_server.main(parameters, 1025)
            assert len(prediction_server._old_handlers) == 1
            prediction_server._old_handlers.clear()
            assert type(prediction_server.app.wsgi_app) is WSGIServer
            assert prediction_server.app.wsgi_app.started is False
            assert prediction_server.app.wsgi_app.closed is True
            # Each stop signal must close a started server.
            for sig in constants.stop_signals:
                prediction_server.app.wsgi_app.start()
                assert prediction_server.app.wsgi_app.started is True
                prediction_server._signal_handler(sig=sig, frame=None)
                assert prediction_server.app.wsgi_app.closed is True
            # Same check with the whole signal sequence repeated twice.
            for sig in constants.stop_signals * 2:
                prediction_server.app.wsgi_app.start()
                assert prediction_server.app.wsgi_app.started is True
                prediction_server._signal_handler(sig=sig, frame=None)
                assert prediction_server.app.wsgi_app.closed is True
            # A signal outside stop_signals must leave the server open.
            for sig in [signal.SIGILL]:
                assert sig not in constants.stop_signals
                prediction_server.app.wsgi_app.start()
                assert prediction_server.app.wsgi_app.started is True
                prediction_server._signal_handler(sig=sig, frame=None)
                assert prediction_server.app.wsgi_app.closed is False
        finally:
            utils.remove_and_check(vectorizer_file)
示例#11
0
def test__parse_accepted_probs():
    """Check ``Parameters._parse_accepted_probs`` on known inputs.

    '1,2,3,2' must parse to the set ``{1, 2, 3}``, the empty string must
    raise ``ValueError``, and lists containing 0 or -1 must raise
    ``AssertionError``.
    """
    parse = Parameters._parse_accepted_probs
    assert parse('1,2,3,2') == {1, 2, 3}
    with pytest.raises(ValueError):
        parse('')
    for rejected in ('1,0,2,3', '1,-1,2,3'):
        with pytest.raises(AssertionError):
            parse(rejected)
示例#12
0
def test_load_20newsgroups():
    """Check ``trainer.load_20newsgroups`` on a first and a repeated call.

    The first call must return a new, unequal parameters object whose column
    names are 'data' and 'target', and must leave an Excel file with 18846
    rows and three columns. The second call with the same file must leave
    its modification time untouched and return an object with the same
    attributes as the first result.
    """
    source_params = Parameters(utils.config_file)
    source_params.excel_file = '20newsgroups'
    workbook = utils.generate_available_filename('.xlsx')
    try:
        first_result = trainer.load_20newsgroups(source_params, workbook)
        assert source_params is not first_result
        assert source_params != first_result
        assert first_result.excel_column_with_text_data == 'data'
        assert first_result.excel_column_with_classification_data == 'target'
        assert os.path.exists(workbook)
        frame = pd.read_excel(workbook)
        assert frame.shape == (18846, 3)
        assert list(frame.keys()) == ['Unnamed: 0', 'data', 'target']
        mtime_before = os.path.getmtime(workbook)
        second_result = trainer.load_20newsgroups(source_params, workbook)
        assert os.path.getmtime(workbook) == mtime_before
        assert second_result.__dict__ == first_result.__dict__
    finally:
        utils.remove_and_check('20news-bydate_py3.pkz')
        utils.remove_and_check(workbook)
示例#13
0
def test__parse_classifiers():
    """Check ``Parameters._parse_classifiers`` on known inputs.

    A comma-separated list of the class names below must parse to the list
    of those classes in the same order. The empty string, and the same list
    with 'invalid_value' appended, must raise ``AssertionError``.
    """
    parse = Parameters._parse_classifiers
    expected_classes = [
        classifiers.RandomForestClassifier,
        classifiers.BernoulliNB,
        classifiers.MultinomialNB,
        classifiers.ComplementNB,
        classifiers.KNeighborsClassifier,
        classifiers.MLPClassifier,
        classifiers.LinearSVC,
        classifiers.DecisionTreeClassifier,
        classifiers.ExtraTreeClassifier,
        classifiers.DummyClassifier,
        classifiers.SGDClassifier,
        classifiers.BaggingClassifier,
    ]
    names = [cls.__name__ for cls in expected_classes]
    joined_names = ','.join(names)
    assert parse(joined_names) == expected_classes
    with pytest.raises(AssertionError):
        parse('')
    with_invalid_entry = ','.join([joined_names, 'invalid_value'])
    with pytest.raises(AssertionError):
        parse(with_invalid_entry)
示例#14
0
def test___init__():
    """Check every attribute of ``Parameters(utils.config_file)``.

    The instance ``__dict__`` must equal the expected mapping below exactly.
    """
    expected_attributes = {
        'excel_file': abspath(utils.example_excel_file),
        'excel_column_with_text_data': 'Example column',
        'excel_column_with_classification_data': 'Classification column',
        'nltk_stop_words_package': 'english',
        'number_of_jobs': cpu_count(),
        'mosestokenizer_language_code': 'en',
        'preprocessed_data_file': abspath('./data/preprocessed_data.pkl'),
        'preprocess_data': True,
        'document_adjustment_code': abspath('./text_categorizer/document_updater.py'),
        'vectorizer': 'TfidfVectorizer',
        'feature_reduction': None,
        'set_num_accepted_probs': {1, 2, 3},
        'classifiers': test_classifiers.clfs,
        'test_subset_size': 0.3,
        'force_subsets_regeneration': False,
        'remove_adjectives': False,
        'synonyms_file': None,
        'resampling': None,
        'class_weights': None,
        'generate_roc_plots': False,
        'spell_checker_lang': None,
        'final_training': False,
        'data_dir': abspath('./data'),
    }
    loaded = Parameters(utils.config_file)
    assert loaded.__dict__ == expected_attributes
示例#15
0
def test_main(monkeypatch, capsys):
    """Check how ``__main__.main`` reacts to valid and invalid command lines.

    ``trainer.main`` and ``prediction_server.main`` are replaced by no-op
    stubs after checking that the real functions take exactly the parameters
    the stubs declare. The first four argument vectors are malformed and
    must print the usage text; the last two are well formed and must print
    nothing. Nothing may be written to stderr in either case.

    The data directory named by the configuration must exist afterwards. It
    is removed again if this test created it, even when an assertion fails.
    """
    all_args = [
        ['text_categorizer'],
        ['text_categorizer', 'invalid_arg'],
        ['text_categorizer', '--trainer', config_file, 'invalid_arg'],
        ['text_categorizer', '--prediction_server', config_file, '5000', 'invalid_arg'],
        ['text_categorizer', '--trainer', config_file],
        ['text_categorizer', '--prediction_server', config_file, '5000'],
    ]
    # Only the last two argument vectors are valid invocations.
    n_invalid = len(all_args) - 2
    parameters = Parameters(config_file)
    data_dir = parameters.data_dir
    data_dir_already_existed = exists(data_dir)
    try:
        with monkeypatch.context() as m:
            trainer_main_code = trainer.main.__code__
            prediction_server_main_code = prediction_server.main.__code__
            # Guard against the stubs drifting away from the real signatures.
            assert trainer_main_code.co_varnames[
                :trainer_main_code.co_argcount] == ('parameters', )
            assert prediction_server_main_code.co_varnames[
                :prediction_server_main_code.co_argcount] == (
                    'parameters',
                    'port',
                )
            m.setattr("text_categorizer.trainer.main", lambda parameters: None)
            m.setattr("text_categorizer.prediction_server.main",
                      lambda parameters, port: None)
            for i, argv in enumerate(all_args):
                __main__.main(argv)
                captured = capsys.readouterr()
                assert captured.out == (_expected_usage if i < n_invalid else '')
                assert captured.err == ''
        assert exists(data_dir)
    finally:
        # Clean up in ``finally`` so a failing assertion above does not leave
        # a directory behind that this test created.
        if not data_dir_already_existed and exists(data_dir):
            rmtree(data_dir)
示例#16
0
def test__parse_None():
    """Check ``Parameters._parse_None`` on known inputs.

    '1', 2 and ``None`` must be handed back as the very same objects, while
    the string 'None' must be converted to ``None``.
    """
    parse = Parameters._parse_None
    for passthrough in ('1', 2, None):
        assert parse(passthrough) is passthrough
    assert parse('None') is None
示例#17
0
def test__parse_synonyms_file():
    """Check ``Parameters._parse_synonyms_file`` on known inputs.

    The string 'None' must yield ``None``; 'valid_filename' must be returned
    as its absolute path.
    """
    parse = Parameters._parse_synonyms_file
    assert parse('None') is None
    candidate = 'valid_filename'
    resolved = parse(candidate)
    assert resolved == abspath(candidate)