def test_combiner_computation(
    self, data, vocab_size, expected_accumulator_output, expected_extract_output
):
    """Check the combiner's accumulate and extract steps against expectations.

    The arguments are presumably supplied by a parameterization decorator
    outside this view -- TODO confirm.

    Args:
      data: Input handed to both validation helpers.
      vocab_size: Forwarded unchanged to `_IndexLookupCombiner`.
      expected_accumulator_output: Mapping passed to `self.update_accumulator`
        to populate the reference accumulator.
      expected_extract_output: Expected result for the extract validation.
    """
    lookup_combiner = index_lookup._IndexLookupCombiner(vocab_size=vocab_size)
    # Build the reference from an accumulator created by the combiner itself,
    # so the comparison below is between objects of the same kind.
    reference = self.update_accumulator(
        lookup_combiner._create_accumulator(), expected_accumulator_output
    )
    self.validate_accumulator_computation(lookup_combiner, data, reference)
    self.validate_accumulator_extract(
        lookup_combiner, data, expected_extract_output
    )
def test_combiner_api_compatibility_int_mode(self):
    """Exercise the combiner accumulator API on integer token input.

    Checks serialization round-trip, accumulator uniqueness and extraction
    using the validation helpers provided by the test base class.
    """
    int_rows = np.array([[42, 1138, 725, 1729], [42, 1138, 725, 203]])
    lookup_combiner = index_lookup._IndexLookupCombiner()

    # Per-token counts across both rows: 42, 1138 and 725 occur twice,
    # 1729 and 203 once.
    accumulated_state = {
        "vocab": np.array([1138, 725, 42, 1729, 203]),
        "counts": np.array([2, 2, 2, 1, 1]),
    }
    # The fixture expects extraction ordered by count, then token value,
    # both descending.
    extracted_state = {
        "vocab": np.array([1138, 725, 42, 1729, 203]),
    }

    reference = self.update_accumulator(
        lookup_combiner._create_accumulator(), accumulated_state
    )

    self.validate_accumulator_serialize_and_deserialize(
        lookup_combiner, int_rows, reference
    )
    self.validate_accumulator_uniqueness(lookup_combiner, int_rows)
    self.validate_accumulator_extract(lookup_combiner, int_rows, extracted_state)
# NOTE: this test previously reused the name
# `test_combiner_api_compatibility_int_mode`. A second `def` with the same name
# in a class body rebinds the attribute, so the integer-input test was silently
# replaced and never collected. The name now reflects its string fixture.
def test_combiner_api_compatibility_string_mode(self):
    """Exercise the combiner accumulator API on string token input.

    Checks serialization round-trip, accumulator uniqueness and extraction
    using the validation helpers provided by the test base class.
    """
    data = np.array([["earth", "wind", "and", "fire"],
                     ["earth", "wind", "and", "michigan"]])
    combiner = index_lookup._IndexLookupCombiner()

    # Per-token counts across both rows: "earth", "wind" and "and" occur
    # twice, "fire" and "michigan" once.
    expected_accumulator_output = {
        "vocab": np.array(["and", "earth", "wind", "fire", "michigan"]),
        "counts": np.array([2, 2, 2, 1, 1]),
    }
    # The fixture expects extraction ordered by count, then token, both
    # descending.
    expected_extract_output = {
        "vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
    }

    expected_accumulator = combiner._create_accumulator()
    expected_accumulator = self.update_accumulator(expected_accumulator,
                                                   expected_accumulator_output)
    self.validate_accumulator_serialize_and_deserialize(combiner, data,
                                                        expected_accumulator)
    self.validate_accumulator_uniqueness(combiner, data)
    self.validate_accumulator_extract(combiner, data, expected_extract_output)