コード例 #1
0
ファイル: test_trainer.py プロジェクト: zxlzr/cherry
 def test_init_pass_data(self, mock_train, mock_cache, mock_load, mock_stop_words):
     mock_stop_words.return_value = ['first', 'second']
     mock_load.return_value = self.x_data, self.y_data
     t = Trainer(
         model='harmful', vectorizer=None, vectorizer_method=None,
         clf=None, clf_method=None, x_data=self.x_data, y_data=self.y_data)
     mock_load.assert_not_called()
コード例 #2
0
ファイル: test_trainer.py プロジェクト: zxlzr/cherry
 def test_write_cache(self, mock_train, mock_stop_words):
     '''
     TODO: use create special cache files for testing
     '''
     mock_stop_words.return_value = ['first', 'second']
     t = Trainer(
         model='harmful', vectorizer=None, vectorizer_method=None,
         clf=None, clf_method=None, x_data=self.x_data, y_data=self.y_data)
     self.assertTrue(os.path.exists(os.path.join(DATA_DIR, 'harmful/trained.pkl')))
     self.assertTrue(os.path.exists(os.path.join(DATA_DIR, 'harmful/ve.pkl')))
コード例 #3
0
ファイル: test_trainer.py プロジェクト: wuqifhb/cherry
    def test_test_data_num_with_custom_split(self, mock_cut):
        mock_cut.return_value = ['警方', '发布', '了', '最新消息']

        def split_function(text):
            stop_word = ['但是']
            return [
                t for t in jieba.cut(text) if len(t) > 1
                and t not in stop_word]
        trainer = Trainer(test_num=0, lan='Chinese', split=split_function)
        self.assertEqual(
            trainer.test_num, 0)
コード例 #4
0
 def test_mock_init_call(self, mock_load_data, mock_get, mock_train,
                         mock_write_cache):
     meta_data_c = namedtuple('meta_data_c', ['data', 'target'])
     mock_load_data.return_value = meta_data_c(data=['random'], target=[2])
     mock_get.return_value = ('vectorizer', 'clf')
     candidates = [('English', 'Count', 'MNB'),
                   ('Chinese', 'Tfidf', 'Random')]
     for candidate in candidates:
         language, vectorizer_method, clf_method = candidate
         t = Trainer(model=self.foo_model,
                     language=language,
                     vectorizer_method=vectorizer_method,
                     clf_method=clf_method)
         mock_load_data.assert_called_with(self.foo_model,
                                           categories=None,
                                           encoding=None)
         mock_get.assert_called_with(language, None, None,
                                     vectorizer_method, clf_method)
         mock_train.assert_called_with('vectorizer', 'clf', ['random'], [2])
         mock_write_cache.assert_called_with('foo', 'clf', 'clf.pkz')
コード例 #5
0
 def test_train_default(self, mock_fix):
     x_data, y_data = ['random'], [2]
     t = Trainer._train(CountVectorizer, MultinomialNB, x_data, y_data)
     mock_fix.assert_called_with(['random'], [2])
コード例 #6
0
 def test_cache_not_found(self):
     with self.assertRaises(
             cherry.exceptions.FilesNotFoundError) as filesNotFoundError:
         t = Trainer(model='foo')
コード例 #7
0
ファイル: test_trainer.py プロジェクト: zxlzr/cherry
 def test_train_fit(self, mock_fit, mock_cache, mock_stop_words):
     mock_stop_words.return_value = ['first', 'second']
     t = Trainer(
         model='harmful', vectorizer=None, vectorizer_method=None,
         clf=None, clf_method=None, x_data=self.x_data, y_data=self.y_data)
     mock_fit.assert_called_once_with(self.x_data, self.y_data)
コード例 #8
0
ファイル: test_trainer.py プロジェクト: zxlzr/cherry
 def test_init_no_vectorizer(self, mock_train, mock_cache, mock_vect):
     t = Trainer(
         model='harmful', vectorizer=None, vectorizer_method=None,
         clf=None, clf_method=None, x_data=self.x_data, y_data=self.y_data)
     mock_vect.assert_called_once_with('harmful', None)
コード例 #9
0
ファイル: test_trainer.py プロジェクト: wuqifhb/cherry
 def test_empty_vocab_list(self, mock_cut):
     trainer = Trainer(test_num=1, lan='Chinese', split=None)
     self.assertEqual(
         trainer.vocab_set, set())
コード例 #10
0
ファイル: test_trainer.py プロジェクト: wuqifhb/cherry
 def test_test_data_num(self, mock_cut):
     trainer = Trainer(test_num=1, lan='Chinese', split=None)
     self.assertEqual(
         trainer.test_num, 1)