예제 #1
0
    def test_MaxTokenSelector(
        self,
        max_nb_tokens,
        expected_columns,
    ):  # pylint: disable=g-doc-args
        r"""Tests HeuristicExactMatchTokenSelector with column-level selection.

        - Test scoring: the columns are ordered 2, 0, 1.
        - The cleaned tokens are created correctly, i.e. punctuation is
          removed. If it were not, the order would change and become 0, 1, 2.
        - The number of selected columns respects the maximum number of tokens.
          With max_nb_tokens = 8 the columns are ordered 2, 0, 1 and the
          respective numbers of tokens are 4, 3 and 3, so only one column,
          column 2, is selected.
        - In the second case, column 1 is included rather than column 0, even
          though column 0 has a better score, because column 0 doesn't fit in
          the budget.
        - In the third case, column 0 is included because there is more
          budget.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            vocab_file = self._create_vocab(temp_dir)
            selector = pruning_utils.HeuristicExactMatchTokenSelector(
                vocab_file=vocab_file,
                max_nb_tokens=max_nb_tokens,
                selection_level=pruning_utils.SelectionType.COLUMN,
                use_previous_questions=False,
                use_previous_answer=False)
            interaction = self._get_interaction()
            # Selection stays inside the `with` block: the temporary directory
            # holding the vocab file is removed when the block exits.
            selected_columns = selector.select_tokens(interaction,
                                                      interaction.questions[0])
        self._assert_equals_selected_columns(expected_columns,
                                             selected_columns.selected_tokens)
예제 #2
0
def _get_token_selector():
    """Returns the column pruning selector, or None if pruning is disabled.

    Returns:
      `None` when `FLAGS.prune_columns` is falsy. Otherwise a
      `pruning_utils.HeuristicExactMatchTokenSelector` that selects at column
      level, reads its vocabulary from `FLAGS.bert_vocab_file`, uses
      `FLAGS.max_seq_length` as its token budget, and takes previous
      questions and answers into account.
    """
    if FLAGS.prune_columns:
        return pruning_utils.HeuristicExactMatchTokenSelector(
            vocab_file=FLAGS.bert_vocab_file,
            max_nb_tokens=FLAGS.max_seq_length,
            selection_level=pruning_utils.SelectionType.COLUMN,
            # Only relevant for SQA where questions come in sequence
            use_previous_answer=True,
            use_previous_questions=True,
        )
    return None
예제 #3
0
    def test_SelectHeuristicExactMatchCellsFn(self):
        r"""Tests HeuristicExactMatch behaviour for cell selection.

        Builds a cell-level HeuristicExactMatchTokenSelector with a budget of
        12 tokens. It then checks both the exact set of selected token
        coordinates and the per-column debug information (score and
        selection flag) the selector reports.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            vocab_file = self._create_vocab(temp_dir)
            selector = pruning_utils.HeuristicExactMatchTokenSelector(
                vocab_file=vocab_file,
                max_nb_tokens=12,
                selection_level=pruning_utils.SelectionType.CELL,
                use_previous_questions=False,
                use_previous_answer=False)
            interaction = self._get_interaction()
            # Selection stays inside the `with` block: the temporary directory
            # holding the vocab file is removed when the block exits.
            selected_cells = selector.select_tokens(interaction,
                                                    interaction.questions[0])
        selected_tokens = selected_cells.selected_tokens

        expected_tokens = {
            _Coordinates(0, 0, 0),
            _Coordinates(0, 0, 1),
            _Coordinates(1, 0, 0),
            _Coordinates(0, 2, 0),
            _Coordinates(2, 0, 0),
            _Coordinates(1, 2, 0),
            _Coordinates(2, 2, 0),
        }
        self.assertEqual(expected_tokens, selected_tokens)

        # All three columns are scored and all are marked as selected.
        expected_debug = """
        columns {
          index: 0
          score:0.5333333333333333
          is_selected: true
        }
        columns {
          index: 1
          score:0.2
          is_selected: true
        }
        columns {
          index: 2
          score:0.5666666666666667
          is_selected: true
        }
      """
        # Expected value first, matching the token assertion above.
        self.assertEqual(
            text_format.Parse(expected_debug,
                              table_selection_pb2.TableSelection.DebugInfo()),
            selected_cells.debug,
        )