Beispiel #1
0
 def test_set_params(self):
     """Check that ``set_params`` updates nested parameters.

     Targets are addressed with double-underscore paths, covering
     container attributes, fitted implementation objects, entries inside
     a compound schema, and replacement of schema parts.
     """
     # Set the arguments in container classes
     tw = TextWiser(Embedding.TfIdf(min_df=5),
                    Transformation.NMF(n_components=30),
                    lazy_load=True)
     tw.set_params(embedding__min_df=10,
                   transformations__0__n_components=10)
     self.assertEqual(tw.embedding.min_df, 10)
     self.assertEqual(tw.transformations[0].n_components, 10)
     # Set the arguments in implementation
     tw = TextWiser(Embedding.Doc2Vec(vector_size=2, min_count=1,
                                      workers=1))
     tw.fit(docs)
     tw.set_params(_imp__0__seed=10)
     self.assertEqual(tw._imp[0].seed, 10)
     # Set the arguments in a schema
     schema = {'transform': ['tfidf', ['nmf', {'n_components': 30}]]}
     tw = TextWiser(Embedding.Compound(schema=schema))
     tw.set_params(embedding__schema__transform__0__min_df=10,
                   embedding__schema__transform__1__n_components=10)
     self.assertEqual(tw.embedding.schema['transform'][0][1]['min_df'], 10)
     self.assertEqual(
         tw.embedding.schema['transform'][1][1]['n_components'], 10)
     # Replace a part of the schema in a list
     tw.set_params(embedding__schema__transform__0='bow')
     self.assertEqual(tw.embedding.schema['transform'][0], 'bow')
     # Replace a part of the schema
     tw.set_params(embedding__schema__transform=['bow'])
     # Compare the whole list: index 0 was already 'bow' after the previous
     # step, so checking only that element would pass even if this call
     # left the schema untouched.
     self.assertEqual(tw.embedding.schema['transform'], ['bow'])
Beispiel #2
0
 def test_fit_transform(self):
     """Run the shared fit/transform checks on a TfIdf + NMF pipeline.

     The pipeline output is compared against a fixed 3x2 float32 matrix
     by the two helper checks.
     """
     model = TextWiser(
         Embedding.TfIdf(min_df=2),
         Transformation.NMF(n_components=2),
         dtype=torch.float32,
     )
     reference_rows = [
         [0.8865839243, 0.0000000000],
         [0.6736079454, 0.5221673250],
         [0.0203559380, 1.1122620106],
     ]
     reference = torch.tensor(reference_rows, dtype=torch.float32)
     self._test_fit_transform(model, reference)
     # Call the seed-reset helper between the two checks.
     self._reset_seed()
     self._test_fit_before_transform(model, reference, atol=1e-5)
Beispiel #3
0
 def test_min_components(self):
     """Expect a ValueError when building the pipeline with n_components=1."""
     with self.assertRaises(ValueError):
         # Everything is built inside the context so that a ValueError
         # raised by any of the constructors is caught.
         embedding = Embedding.TfIdf(min_df=2)
         reduction = Transformation.NMF(n_components=1)
         TextWiser(embedding, reduction, dtype=torch.float32)