예제 #1
0
 def test_combiner_computation(self, data, vocab_size,
                               expected_accumulator_output,
                               expected_extract_output):
   """Checks the combiner's accumulate and extract steps on `data`.

   A combiner is constructed with `vocab_size`. A fresh accumulator is
   populated from `expected_accumulator_output` via `update_accumulator` and
   used as the reference for the accumulator the combiner computes from
   `data`. The extracted result is compared against `expected_extract_output`.

   Args:
     data: Input passed to the combiner under test.
     vocab_size: Forwarded to `_IndexLookupCombiner`.
     expected_accumulator_output: Values loaded into the reference
       accumulator.
     expected_extract_output: Expected result of the extract step.
   """
   combiner = index_lookup._IndexLookupCombiner(vocab_size=vocab_size)
   reference_accumulator = self.update_accumulator(
       combiner._create_accumulator(), expected_accumulator_output)
   self.validate_accumulator_computation(combiner, data, reference_accumulator)
   self.validate_accumulator_extract(combiner, data, expected_extract_output)
예제 #2
0
 def test_combiner_api_compatibility_int_mode(self):
   """Exercises the combiner API on integer-valued input.

   Builds a reference accumulator from the expected vocab/counts, then runs
   the serialize/deserialize round-trip check, the uniqueness check, and the
   extract check against it.
   """
   data = np.array([[42, 1138, 725, 1729], [42, 1138, 725, 203]])
   combiner = index_lookup._IndexLookupCombiner()

   # The expected vocabulary lists values by descending count (2, 2, 2, 1, 1),
   # with equal counts listed in descending value order.
   expected_vocab = [1138, 725, 42, 1729, 203]
   expected_counts = [2, 2, 2, 1, 1]
   expected_accumulator_output = dict(
       vocab=np.array(expected_vocab),
       counts=np.array(expected_counts),
   )
   expected_extract_output = dict(vocab=np.array(expected_vocab))

   reference_accumulator = self.update_accumulator(
       combiner._create_accumulator(), expected_accumulator_output)
   self.validate_accumulator_serialize_and_deserialize(
       combiner, data, reference_accumulator)
   self.validate_accumulator_uniqueness(combiner, data)
   self.validate_accumulator_extract(combiner, data, expected_extract_output)
예제 #3
0
 def test_combiner_api_compatibility_string_mode(self):
   """Exercises the combiner API on string-valued input.

   Renamed from `test_combiner_api_compatibility_int_mode`. That name was
   wrong for this string-data test and duplicated the integer-mode test's
   name. In a single class, the later definition would silently replace the
   earlier one, so the integer test would never run.

   Builds a reference accumulator from the expected vocab/counts, then runs
   the serialize/deserialize round-trip check, the uniqueness check, and the
   extract check against it.
   """
   data = np.array([["earth", "wind", "and", "fire"],
                    ["earth", "wind", "and", "michigan"]])
   combiner = index_lookup._IndexLookupCombiner()
   expected_accumulator_output = {
       "vocab": np.array(["and", "earth", "wind", "fire", "michigan"]),
       "counts": np.array([2, 2, 2, 1, 1]),
   }
   # The expected extract output lists terms by descending count, with equal
   # counts in descending lexicographic order.
   expected_extract_output = {
       "vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
   }
   expected_accumulator = combiner._create_accumulator()
   expected_accumulator = self.update_accumulator(expected_accumulator,
                                                  expected_accumulator_output)
   self.validate_accumulator_serialize_and_deserialize(combiner, data,
                                                       expected_accumulator)
   self.validate_accumulator_uniqueness(combiner, data)
   self.validate_accumulator_extract(combiner, data, expected_extract_output)