def test_MaxTokenSelector(
    self,
    max_nb_tokens,
    expected_columns,
):  # pylint: disable=g-doc-args
  r"""Checks column selection by HeuristicExactMatchTokenSelector under a budget.

  What the parameterized cases are meant to exercise:
  - Scoring: the columns are ranked 2, 0, 1.
  - Token cleaning: punctuation has to be removed from the tokens. If it is
    not, the ranking changes and becomes 0, 1, 2.
  - Budgeting: only as many columns are selected as fit within the maximum
    number of tokens. With max_nb_tokens = 8, the columns ranked 2, 0, 1
    have 4, 3, 3 tokens respectively, so only column 2 is selected.
  - In the second case, column 1 is included rather than column 0, even
    though 0 has a better score, because 0 doesn't fit in the budget.
  - In the third case, column 0 is included because the budget is larger.
  """
  with tempfile.TemporaryDirectory() as tmp_dir:
    # Everything runs inside the context so the vocab file stays on disk for
    # as long as the selector may need it.
    vocab_path = self._create_vocab(tmp_dir)
    column_selector = pruning_utils.HeuristicExactMatchTokenSelector(
        vocab_file=vocab_path,
        max_nb_tokens=max_nb_tokens,
        selection_level=pruning_utils.SelectionType.COLUMN,
        use_previous_questions=False,
        use_previous_answer=False,
    )
    example = self._get_interaction()
    first_question = example.questions[0]
    selection = column_selector.select_tokens(example, first_question)
    self._assert_equals_selected_columns(
        expected_columns,
        selection.selected_tokens,
    )
def _get_token_selector():
  """Builds the column-level token selector when column pruning is enabled.

  Returns:
    None if `FLAGS.prune_columns` is falsy. Otherwise a
    `pruning_utils.HeuristicExactMatchTokenSelector` built from
    `FLAGS.bert_vocab_file` and `FLAGS.max_seq_length`, selecting at
    `SelectionType.COLUMN` level with previous answers and previous
    questions both enabled.
  """
  if FLAGS.prune_columns:
    return pruning_utils.HeuristicExactMatchTokenSelector(
        FLAGS.bert_vocab_file,
        FLAGS.max_seq_length,
        pruning_utils.SelectionType.COLUMN,
        # Only relevant for SQA where questions come in sequence
        use_previous_answer=True,
        use_previous_questions=True,
    )
  # Pruning disabled: callers receive no selector at all.
  return None
def test_SelectHeuristicExactMatchCellsFn(self):
  r"""Checks cell-level selection by HeuristicExactMatchTokenSelector.

  Runs the selector with a budget of 12 tokens at `SelectionType.CELL` level
  on the first question of the test interaction, then verifies both the set
  of selected token coordinates and the per-column debug info (score and
  is_selected flag for columns 0, 1 and 2).
  """
  with tempfile.TemporaryDirectory() as tmp_dir:
    vocab_path = self._create_vocab(tmp_dir)
    cell_selector = pruning_utils.HeuristicExactMatchTokenSelector(
        vocab_file=vocab_path,
        max_nb_tokens=12,
        selection_level=pruning_utils.SelectionType.CELL,
        use_previous_questions=False,
        use_previous_answer=False,
    )
    example = self._get_interaction()
    selection = cell_selector.select_tokens(example, example.questions[0])

    # Expected coordinates of the tokens that survive the selection.
    wanted_tokens = {
        _Coordinates(0, 0, 0),
        _Coordinates(0, 0, 1),
        _Coordinates(1, 0, 0),
        _Coordinates(0, 2, 0),
        _Coordinates(2, 0, 0),
        _Coordinates(1, 2, 0),
        _Coordinates(2, 2, 0),
    }
    self.assertEqual(wanted_tokens, selection.selected_tokens)

    # All three columns are expected to be flagged as selected, with these
    # exact scores.
    wanted_debug = text_format.Parse(
        """ columns { index: 0 score:0.5333333333333333 is_selected: true } columns { index: 1 score:0.2 is_selected: true } columns { index: 2 score:0.5666666666666667 is_selected: true } """,
        table_selection_pb2.TableSelection.DebugInfo(),
    )
    self.assertEqual(selection.debug, wanted_debug)