Ejemplo n.º 1
0
 def test_combiner_api_compatibility_int_mode(self):
   """Runs the shared combiner API checks on an int-mode (no IDF) combiner.

   Builds a `_TextVectorizationCombiner` with `compute_idf=False` and passes it,
   together with a two-document string batch, to
   `validate_accumulator_serialize_and_deserialize` and
   `validate_accumulator_uniqueness`.
   """
   # Two documents; "earth", "wind" and "and" appear in both rows, while
   # "fire" and "michigan" appear in only one row each.
   documents = np.array([
       ["earth", "wind", "and", "fire"],
       ["earth", "wind", "and", "michigan"],
   ])
   int_mode_combiner = text_vectorization._TextVectorizationCombiner(
       compute_idf=False)
   # The twice-seen tokens are listed before the once-seen ones.
   expected_accumulator_output = {
       "vocab": np.array(["and", "earth", "wind", "fire", "michigan"]),
   }
   self.validate_accumulator_serialize_and_deserialize(
       int_mode_combiner, documents, expected_accumulator_output)
   self.validate_accumulator_uniqueness(int_mode_combiner, documents)
Ejemplo n.º 2
0
 def test_combiner_computation(
     self, data, vocab_size, expected_accumulator_output,
     expected_extract_output, compute_idf=True):
   """Checks combiner computation and extraction for one parameter set.

   Args:
     data: input batch handed to both validation helpers.
     vocab_size: forwarded to the combiner constructor as `vocab_size`.
     expected_accumulator_output: keyword arguments used to build the
       reference accumulator via the combiner's `_create_accumulator`.
     expected_extract_output: expected value passed to
       `validate_accumulator_extract`.
     compute_idf: forwarded to the combiner constructor as `compute_idf`.
   """
   combiner_under_test = text_vectorization._TextVectorizationCombiner(
       vocab_size=vocab_size, compute_idf=compute_idf)
   # The reference accumulator is built by the same combiner so that it has
   # the same internal layout as the one computed from `data`.
   reference_accumulator = combiner_under_test._create_accumulator(
       **expected_accumulator_output)

   self.validate_accumulator_computation(
       combiner_under_test, data, reference_accumulator)
   self.validate_accumulator_extract(
       combiner_under_test, data, expected_extract_output)
Ejemplo n.º 3
0
 def test_combiner_api_compatibility_tfidf_mode(self):
   """Runs the shared combiner API checks on a tf-idf combiner.

   Builds a `_TextVectorizationCombiner` with `compute_idf=True` and runs
   `validate_accumulator_serialize_and_deserialize`,
   `validate_accumulator_uniqueness` and `validate_accumulator_extract` on a
   two-document string batch.
   """
   # Two documents (rows): "earth", "wind" and "and" occur in both, "fire"
   # and "michigan" in one each.
   data = np.array([["earth", "wind", "and", "fire"],
                    ["earth", "wind", "and", "michigan"]])
   combiner = text_vectorization._TextVectorizationCombiner(compute_idf=True)
   # The expected IDF values are consistent with N = 2 documents under
   # idf = ln(1 + N / (1 + df)): ln(1 + 2/3) ~= 0.510826 for df = 2,
   # ln(2) ~= 0.693147 for df = 1, and ln(3) ~= 1.098612 for an
   # out-of-vocabulary token (df = 0).
   expected_extract_output = {
       "vocab": np.array(["and", "earth", "wind", "fire", "michigan"]),
       "idf": np.array([0.510826, 0.510826, 0.510826, 0.693147, 0.693147]),
       "oov_idf": np.array([1.098612])
   }
   expected_accumulator_output = {
       "vocab": np.array(["and", "earth", "wind", "fire", "michigan"]),
       "counts": np.array([2, 2, 2, 1, 1]),
       "document_counts": np.array([2, 2, 2, 1, 1]),
       # `data` has two rows, i.e. two documents; the IDF values above also
       # assume N = 2.
       # NOTE(review): a wrong value here would only go unnoticed if
       # validate_accumulator_serialize_and_deserialize does not compare
       # against its `expected` argument — confirm in the test utils.
       "num_documents": np.array(2),
   }
   self.validate_accumulator_serialize_and_deserialize(
       combiner, data, expected_accumulator_output)
   self.validate_accumulator_uniqueness(combiner, data)
   self.validate_accumulator_extract(combiner, data, expected_extract_output)