Example #1
    def evaluate(self, ques_rep: torch.Tensor, sampled_actions: List[str],
                 slot_dic: Dict, target_list: List,
                 world: WikiTablesVariableFreeWorld) -> bool:
        _, _token_rnn_feat_size = ques_rep.size()
        assert self.token_rnn_feat_size == _token_rnn_feat_size
        id2column, column2id, column_type_dic, column_reps = self.collect_column_reps(
            world.table_context)

        actions = world.get_valid_actions()
        filtered_actions = self.filter_functions(actions)

        possible_paths = self.get_all_sequences(ques_rep, column2id, column_reps,
                                                sampled_actions, filtered_actions,
                                                slot_dic, world)

        max_path, max_score = possible_paths[0]
        for candidate_path, candidate_score in possible_paths[1:]:
            if candidate_score > max_score:
                max_path = candidate_path
                max_score = candidate_score

        lf = world.get_logical_form(max_path)
        return bool(world._executor.evaluate_logical_form(lf, target_list))
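With the score update in place, the best-path selection above can equivalently be written with Python's built-in max; a minimal sketch:

# Equivalent best-path selection over (path, score) pairs:
max_path, max_score = max(possible_paths, key=lambda pair: pair[1])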
 def test_get_valid_actions_in_world_without_date_columns(self):
     question_tokens = [Token(x) for x in ['what', 'was', 'the', 'first', 'title', '?']]
     table_file = self.FIXTURES_ROOT / 'data' / 'corenlp_processed_tables' / 'TEST-4.table'
     table_context = TableQuestionContext.read_from_file(table_file, question_tokens)
     # The table does not have a date column.
     assert "date" not in table_context.column_types.values()
     world = WikiTablesVariableFreeWorld(table_context)
     actions = world.get_valid_actions()
     assert set(actions.keys()) == {
             "<r,<g,s>>",
             "<r,<f,<n,r>>>",
             "<r,<c,r>>",
             "<r,<g,r>>",
             "<r,<r,<f,n>>>",
             "<r,<t,<s,r>>>",
             "<n,<n,<n,d>>>",
             "<r,<f,n>>",
             "<r,r>",
             "<r,n>",
             "d",
             "n",
             "s",
             "t",
             "f",
             "r",
             "@start@",
             }
     assert set([str(type_) for type_ in world.get_basic_types()]) == {'n', 'd', 's', 'r', 't',
                                                                       'f', 'g', 'c'}
     check_productions_match(actions['s'],
                             ['[<r,<g,s>>, r, f]',
                              '[<r,<g,s>>, r, t]'])
def search(tables_directory: str,
           input_examples_file: str,
           output_file: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool) -> None:
    data = [wikitables_util.parse_example_line(example_line) for example_line in
            open(input_examples_file)]
    tokenizer = WordTokenizer()
    with open(output_file, "w") as output_file_pointer:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            # pylint: disable=protected-access
            target_list = [TableQuestionContext._normalize_string(value) for value in
                           instance_data["target_values"]]
            try:
                target_value_list = evaluator.to_value_list(target_list)
            except Exception:
                # Log the offending target list before re-raising.
                print(target_list, file=sys.stderr)
                raise
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            correct_logical_forms = []
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                agenda = world.get_agenda()
                print(f"Agenda: {agenda}", file=output_file_pointer)
                all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                         max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            for logical_form in all_logical_forms:
                try:
                    denotation = world.execute(logical_form)
                except ExecutionError:
                    print(f"Failed to execute: {logical_form}", file=sys.stderr)
                    continue
                if isinstance(denotation, list):
                    denotation_list = [str(denotation_item) for denotation_item in denotation]
                else:
                    # For numbers and dates
                    denotation_list = [str(denotation)]
                denotation_value_list = evaluator.to_value_list(denotation_list)
                if evaluator.check_denotation(target_value_list, denotation_value_list):
                    correct_logical_forms.append(logical_form)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
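A hypothetical command-line driver for the search function above; every flag name here is an assumption for illustration, not part of the original script:

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Offline logical-form search (hypothetical driver).")
    parser.add_argument("tables_directory")
    parser.add_argument("input_examples_file")
    parser.add_argument("output_file")
    parser.add_argument("--max-path-length", type=int, default=10)
    parser.add_argument("--max-num-logical-forms", type=int, default=100)
    parser.add_argument("--use-agenda", action="store_true")
    args = parser.parse_args()
    search(args.tables_directory, args.input_examples_file, args.output_file,
           args.max_path_length, args.max_num_logical_forms, args.use_agenda)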
Example #4
    def _walk(self, actions: Dict, prod_score_dic: Dict,
              world: WikiTablesVariableFreeWorld) -> List:
        """
        search in the action space without data selection operations like lookup.
        the operations used reflect the semantics of a question, it is more abstract and the space would be much smaller
        """
        # Buffer of NTs to expand, previous actions
        incomplete_paths = [([str(type_)], [f"{START_SYMBOL} -> {type_}"],
                             prod_score_dic[f"{START_SYMBOL} -> {type_}"])
                            for type_ in world.get_valid_starting_types()]

        _completed_paths = []
        multi_match_substitutions = world.get_multi_match_mapping()

        while incomplete_paths:
            next_paths = []
            for nonterminal_buffer, history, cur_score in incomplete_paths:
                # Taking the last non-terminal added to the buffer. We're going depth-first.
                nonterminal = nonterminal_buffer.pop()
                next_actions = []
                if nonterminal in multi_match_substitutions:
                    for current_nonterminal in [
                            nonterminal
                    ] + multi_match_substitutions[nonterminal]:
                        if current_nonterminal in actions:
                            next_actions.extend(actions[current_nonterminal])
                elif nonterminal not in actions:
                    continue
                else:
                    next_actions.extend(actions[nonterminal])
                # Iterating over all possible next actions.
                for action in next_actions:
                    if action not in [
                            "e -> mul_row_select", "r -> one_row_select"
                    ] and action in history:
                        continue
                    new_history = history + [action]
                    new_nonterminal_buffer = nonterminal_buffer[:]
                    # Since we expand the last action added to the buffer, the left child should be
                    # added after the right child.
                    for right_side_part in reversed(
                            self._get_right_side_parts(action)):
                        if types.is_nonterminal(right_side_part):
                            new_nonterminal_buffer.append(right_side_part)
                    new_prod_score = prod_score_dic[action] + cur_score
                    next_paths.append(
                        (new_nonterminal_buffer, new_history, new_prod_score))
            incomplete_paths = []
            for nonterminal_buffer, path, score in next_paths:
                # An empty buffer means that we've completed this path.
                if not nonterminal_buffer:
                    # A path with only two actions is just a start production
                    # plus a single terminal, so we skip it:
                    if len(path) > 2:
                        _completed_paths.append((path, score))
                elif len(path) < self._max_path_length:
                    incomplete_paths.append((nonterminal_buffer, path, score))
        return _completed_paths
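The buffer in _walk is a stack of pending non-terminals: pop the most recent one, apply every production, and push right-hand-side non-terminals in reverse so the left child is expanded first. A self-contained toy sketch of the same loop, with an invented two-rule grammar:

# Toy illustration of the nonterminal-buffer walk; the grammar is invented.
toy_actions = {
    "s": ["s -> [add, n, n]", "s -> [neg, n]"],
    "n": ["n -> 1", "n -> 2"],
}

def toy_walk(start: str, max_path_length: int = 4):
    completed = []
    incomplete = [([start], [f"@start@ -> {start}"])]
    while incomplete:
        next_paths = []
        for buffer, history in incomplete:
            nonterminal = buffer.pop()
            for action in toy_actions.get(nonterminal, []):
                new_buffer = buffer[:]
                # Push right-hand-side non-terminals in reverse, so the
                # left child sits on top of the stack.
                right_side = action.split(" -> ")[1].strip("[]").split(", ")
                for part in reversed(right_side):
                    if part in toy_actions:
                        new_buffer.append(part)
                next_paths.append((new_buffer, history + [action]))
        completed.extend(h for b, h in next_paths if not b)
        incomplete = [(b, h) for b, h in next_paths if b and len(h) < max_path_length]
    return completed

print(len(toy_walk("s")))  # 6 complete derivations for this toy grammar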
 def setUp(self):
     super().setUp()
     question_tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']]
     self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tagged'
     self.table_context = TableQuestionContext.read_from_file(self.table_file, question_tokens)
     self.world_with_2013 = WikiTablesVariableFreeWorld(self.table_context)
     usl_league_tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year', 'with', 'usl',
                                             'a', 'league', '?']]
     self.world_with_usl_a_league = self._get_world_with_question_tokens(usl_league_tokens)
 def setUp(self):
     super().setUp()
     question_tokens = [
         Token(x)
         for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']
     ]
     self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tsv'
     self.table_kg = TableQuestionKnowledgeGraph.read_from_file(
         self.table_file, question_tokens)
     self.world = WikiTablesVariableFreeWorld(self.table_kg)
Example #7
def search(tables_directory: str, input_examples_file: str, output_path: str,
           max_path_length: int, max_num_logical_forms: int, use_agenda: bool,
           output_separate_files: bool) -> None:
    data = [
        wikitables_util.parse_example_line(example_line)
        for example_line in open(input_examples_file)
    ]
    tokenizer = WordTokenizer()
    if output_separate_files and not os.path.exists(output_path):
        os.makedirs(output_path)
    if not output_separate_files:
        output_file_pointer = open(output_path, "w")
    for instance_data in data:
        utterance = instance_data["question"]
        question_id = instance_data["id"]
        if utterance.startswith('"') and utterance.endswith('"'):
            utterance = utterance[1:-1]
        # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
        table_file = instance_data["table_filename"].replace("csv", "tagged")
        target_list = instance_data["target_values"]
        tokenized_question = tokenizer.tokenize(utterance)
        table_file = f"{tables_directory}/{table_file}"
        context = TableQuestionContext.read_from_file(table_file,
                                                      tokenized_question)
        world = WikiTablesVariableFreeWorld(context)
        walker = ActionSpaceWalker(world, max_path_length=max_path_length)
        correct_logical_forms = []
        if use_agenda:
            agenda = world.get_agenda()
            all_logical_forms = walker.get_logical_forms_with_agenda(
                agenda=agenda,
                max_num_logical_forms=10000,
                allow_partial_match=True)
        else:
            all_logical_forms = walker.get_all_logical_forms(
                max_num_logical_forms=10000)
        for logical_form in all_logical_forms:
            if world.evaluate_logical_form(logical_form, target_list):
                correct_logical_forms.append(logical_form)
        if output_separate_files and correct_logical_forms:
            with gzip.open(f"{output_path}/{question_id}.gz",
                           "wt") as output_file_pointer:
                for logical_form in correct_logical_forms:
                    print(logical_form, file=output_file_pointer)
        elif not output_separate_files:
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                print(f"Agenda: {agenda}", file=output_file_pointer)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
    if not output_separate_files:
        output_file_pointer.close()
def search(tables_directory: str,
           input_examples_file: str,
           output_path: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool,
           output_separate_files: bool) -> None:
    data = [wikitables_util.parse_example_line(example_line) for example_line in
            open(input_examples_file)]
    tokenizer = WordTokenizer()
    if output_separate_files and not os.path.exists(output_path):
        os.makedirs(output_path)
    if not output_separate_files:
        output_file_pointer = open(output_path, "w")
    for instance_data in data:
        utterance = instance_data["question"]
        question_id = instance_data["id"]
        if utterance.startswith('"') and utterance.endswith('"'):
            utterance = utterance[1:-1]
        # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
        table_file = instance_data["table_filename"].replace("csv", "tagged")
        target_list = instance_data["target_values"]
        tokenized_question = tokenizer.tokenize(utterance)
        table_file = f"{tables_directory}/{table_file}"
        context = TableQuestionContext.read_from_file(table_file, tokenized_question)
        world = WikiTablesVariableFreeWorld(context)
        walker = ActionSpaceWalker(world, max_path_length=max_path_length)
        correct_logical_forms = []
        if use_agenda:
            agenda = world.get_agenda()
            all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                     max_num_logical_forms=10000)
        else:
            all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
        for logical_form in all_logical_forms:
            if world.evaluate_logical_form(logical_form, target_list):
                correct_logical_forms.append(logical_form)
        if output_separate_files and correct_logical_forms:
            with gzip.open(f"{output_path}/{question_id}.gz", "wt") as output_file_pointer:
                for logical_form in correct_logical_forms:
                    print(logical_form, file=output_file_pointer)
        elif not output_separate_files:
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                print(f"Agenda: {agenda}", file=output_file_pointer)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
    if not output_separate_files:
        output_file_pointer.close()
Example #9
    def forward(self, token_reps: List[torch.Tensor],
                world: WikiTablesVariableFreeWorld) -> Tuple:
        actions = world.get_valid_actions()
        actions = self._filter_abstract(actions)

        prod_score_dic = self._score_prod(token_reps, actions, world)
        # sketch_lfs = self.get_all_skethch_lf(actions, prod_score_dic, world)
        sketches = self.get_all_sketches(actions, prod_score_dic, world)
        logger.info("%s sketches generated", len(sketches))

        score_list = [score for _, score in sketches]
        score_vec = torch.stack(score_list, 0)
        lf_prob = F.softmax(score_vec, dim=0)
        m = Categorical(lf_prob)
        lf_sample_t = m.sample()
        lf_sample_idx = lf_sample_t.item()

        sampled_lf_actions, sampled_score = sketches[lf_sample_idx]
        sampled_log_probs = m.log_prob(lf_sample_t)

        slot_rep = self._gen_slot_rep(prod_score_dic, sampled_lf_actions)

        return (sampled_lf_actions, sampled_log_probs, slot_rep)
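The sampled_log_probs returned above is the hook for a REINFORCE-style objective; a hypothetical training step, where sketch_model, compute_reward, and optimizer are assumed names for illustration:

# Hypothetical REINFORCE-style update built on the outputs of forward().
sampled_actions, sampled_log_probs, slot_rep = sketch_model.forward(token_reps, world)
reward = compute_reward(sampled_actions)  # assumed helper returning a float
loss = -reward * sampled_log_probs        # maximize expected reward
optimizer.zero_grad()
loss.backward()
optimizer.step()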
Example #10
    def forward(self, ques_rep: torch.Tensor, sampled_actions: List[str],
                slot_dic: Dict, target_list: List,
                world: WikiTablesVariableFreeWorld) -> torch.Tensor:
        """
        It takes in a sampled path and finish the selection part
        based on alignments to the question, table and fileter/same_as function.

        Operations for selecting one row: filter_eq, filter_in
        Operations for selecting multiple rows: all filters and all_rows
        """
        _, _token_rnn_feat_size = ques_rep.size()
        assert self.token_rnn_feat_size == _token_rnn_feat_size
        id2column, column2id, column_type_dic, column_reps = self.collect_column_reps(
            world.table_context)

        actions = world.get_valid_actions()
        filtered_actions = self.filter_functions(actions)

        possible_paths = self.get_all_sequences(ques_rep, column2id, column_reps,
                                                sampled_actions, filtered_actions,
                                                slot_dic, world)

        correct_lf = []
        candidate_scores = []
        gold_ids = []
        for candidate_path, candidate_score in possible_paths:
            lf = world.get_logical_form(candidate_path)
            candidate_scores.append(candidate_score)
            if world._executor.evaluate_logical_form(lf, target_list):
                correct_lf.append(lf)
                gold_ids.append(1)
            else:
                gold_ids.append(0)

        gold_id_v = torch.FloatTensor(gold_ids)
        if torch.sum(gold_id_v) == 0:
            return 0
        else:
            score_v = torch.stack(candidate_scores, 0)
            score_prob = F.softmax(score_v, 0)
            reward_v = gold_id_v * score_prob
            return torch.sum(reward_v, 0)
Example #11
    def _walk(self, non_terminal: str, actions: List, prod_score_dic: Dict,
              world: WikiTablesVariableFreeWorld) -> List:
        incomplete_paths = [([non_terminal], [f"#S# -> {non_terminal}"], None)]

        _completed_paths = []
        multi_match_substitutions = world.get_multi_match_mapping()

        while incomplete_paths:
            next_paths = []
            for nonterminal_buffer, history, cur_score in incomplete_paths:
                # Taking the last non-terminal added to the buffer. We're going depth-first.
                nonterminal = nonterminal_buffer.pop()
                next_actions = []
                if nonterminal in multi_match_substitutions:
                    for current_nonterminal in [
                            nonterminal
                    ] + multi_match_substitutions[nonterminal]:
                        if current_nonterminal in actions:
                            next_actions.extend(actions[current_nonterminal])
                elif nonterminal not in actions:
                    continue
                else:
                    next_actions.extend(actions[nonterminal])
                # Iterating over all possible next actions.
                for action in next_actions:
                    if action in history:
                        continue
                    new_history = history + [action]
                    new_nonterminal_buffer = nonterminal_buffer[:]
                    # Since we expand the last action added to the buffer, the left child should be
                    # added after the right child.
                    for right_side_part in reversed(
                            self._get_right_side_parts(action)):
                        if types.is_nonterminal(right_side_part):
                            new_nonterminal_buffer.append(right_side_part)
                    if cur_score is None:
                        new_prod_score = prod_score_dic[action]
                    else:
                        new_prod_score = prod_score_dic[action] + cur_score
                    next_paths.append(
                        (new_nonterminal_buffer, new_history, new_prod_score))
            incomplete_paths = []
            for nonterminal_buffer, path, score in next_paths:
                # An empty buffer means that we've completed this path.
                if not nonterminal_buffer:
                    _completed_paths.append((path, score))
                elif len(path) < 6:  #TODO: set this
                    incomplete_paths.append((nonterminal_buffer, path, score))

        strip_path = []
        for path, score in _completed_paths:
            strip_path.append((path[1:], score))  # drop the fake "#S# ->" start action
        return strip_path
Example #12
    def forward_enumerate(self, token_reps: List[torch.Tensor],
                          world: WikiTablesVariableFreeWorld) -> Iterator:
        actions = world.get_valid_actions()
        actions = self._filter_abstract(actions)

        prod_score_dic = self._score_prod(token_reps, actions, world)
        # sketch_lfs = self.get_all_skethch_lf(actions, prod_score_dic, world)
        sketches = self.get_all_sketches(actions, prod_score_dic, world)
        logger.info("%s sketches generated", len(sketches))

        score_list = [score for _, score in sketches]
        score_vec = torch.stack(score_list, 0)
        lf_prob = F.softmax(score_vec, dim=0)

        for i, (lf_actions, lf_score) in enumerate(sketches):
            slot_rep = self._gen_slot_rep(prod_score_dic, lf_actions)
            yield (lf_actions, torch.log(lf_prob[i]), slot_rep)
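Because forward_enumerate is a generator, a caller can score every sketch rather than sampling one; a hypothetical consumer that keeps the most probable sketch (sketch_model is an assumed name):

# Hypothetical consumption of forward_enumerate (a generator):
best = None
for lf_actions, log_prob, slot_rep in sketch_model.forward_enumerate(token_reps, world):
    if best is None or log_prob.item() > best[1].item():
        best = (lf_actions, log_prob, slot_rep)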
Example #13
    def _score_prod(self, token_reps: torch.Tensor, actions: List[str],
                    world: WikiTablesVariableFreeWorld) -> Dict:
        """
        produce scores for each production rule
        return dict that has a scalar score for each production rule 
        """
        sent_len, _token_rnn_feat_size = token_reps.size()
        assert self.token_rnn_feat_size == _token_rnn_feat_size
        action_list = []
        for prods in actions.values():
            action_list += prods
        action_list += [
            f"{START_SYMBOL} -> {type_}"
            for type_ in world.get_valid_starting_types()
        ]
        action_list = list(set(action_list))
        prod_num = len(action_list)

        prod_id_list = [self.prod2id[prod] for prod in action_list]
        prod_id_tensor = torch.LongTensor(prod_id_list)
        prod_id_vec = self.prod_embed(
            prod_id_tensor)  # prod_num * prod_embed_size
        token_mat = token_reps.unsqueeze(0).expand(prod_num, sent_len,
                                                   self.token_rnn_feat_size)

        att_scores = self.att(prod_id_vec, token_mat)  # prod_num * sent_len
        att_rep_vec = torch.mm(att_scores,
                               token_reps)  # prod_num * token_embed
        feat_vec = torch.cat([prod_id_vec, att_rep_vec], 1)
        hidden_vec = F.relu(self.hidden_func(feat_vec))
        score_vec = self.score_func(hidden_vec)  # prod_num * 1
        score_vec = score_vec.squeeze(1)  # prod_num

        prod_score_dic = dict()
        for i, prod in enumerate(action_list):
            prod_score_dic[prod] = score_vec[i]
        return prod_score_dic
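A runnable shape check for the attention-summary step above, with invented sizes standing in for the model's dimensions:

import torch
import torch.nn.functional as F

# Invented sizes: 12 tokens, 200-dim token features, 30 production rules.
sent_len, feat_size, prod_num = 12, 200, 30
token_reps = torch.randn(sent_len, feat_size)
att_scores = F.softmax(torch.randn(prod_num, sent_len), dim=1)  # stand-in for self.att(...)
att_rep_vec = torch.mm(att_scores, token_reps)  # attention-weighted question summary
assert att_rep_vec.shape == (prod_num, feat_size)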
Example #14
    def predict(self, token_reps: List[torch.Tensor],
                world: WikiTablesVariableFreeWorld) -> Tuple:
        actions = world.get_valid_actions()
        actions = self._filter_abstract(actions)

        prod_score_dic = self._score_prod(token_reps, actions, world)
        # sketch_lfs = self.get_all_skethch_lf(actions, prod_score_dic, world)
        sketches = self.get_all_sketches(actions, prod_score_dic, world)
        logger.info("%s sketches generated", len(sketches))

        score_list = [score for _, score in sketches]
        score_vec = torch.stack(score_list, 0)
        lf_prob = F.softmax(score_vec, dim=0)

        max_v, max_id = torch.max(lf_prob, dim=0)

        max_lf_actions, max_score = sketches[max_id]
        max_log_probs = torch.log(lf_prob[max_id])

        slot_rep = self._gen_slot_rep(prod_score_dic, max_lf_actions)

        return (max_lf_actions, max_log_probs, slot_rep)
    def _create_grammar_state(self, world: WikiTablesVariableFreeWorld,
                              possible_actions: List[ProductionRuleArray],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of
        creating that is creating the `valid_actions` dictionary, which contains embedded
        representations of all of the valid actions.  So, we create that here as well.

        The way we represent the valid expansions is a little complicated: we use a
        dictionary of `action types`, where the key is the action type (like "global", "linked", or
        whatever your model is expecting), and the value is a tuple representing all actions of
        that type.  The tuple is (input tensor, output tensor, action id).  The input tensor has
        the representation that is used when `selecting` actions, for all actions of this type.
        The output tensor has the representation that is used when feeding the action to the next
        step of the decoder (this could just be the same as the input tensor).  The action ids are
        a list of indices into the main action list for each batch instance.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRuleArrays``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``WikiTablesVariableFreeWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRuleArray]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        # TODO(mattg): Move the "valid_actions" construction to another method.
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions if any.
            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0)
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                if self._add_action_bias:
                    global_action_biases = self._action_biases(
                        global_action_tensor)
                    global_input_embeddings = torch.cat(
                        [global_input_embeddings, global_action_biases],
                        dim=-1)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(
                    entity_type_tensor)
                translated_valid_actions[key]['linked'] = (
                    entity_linking_scores, entity_type_embeddings,
                    list(linked_action_ids))
        return GrammarStatelet([START_SYMBOL], translated_valid_actions,
                               type_declaration.is_nonterminal)
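An illustrative instance of the translated_valid_actions structure the method returns; all sizes and action ids below are invented:

import torch

num_global, num_linked, action_dim, type_dim, num_tokens = 3, 2, 16, 8, 12
translated_valid_actions = {
    "r": {
        "global": (torch.randn(num_global, action_dim),   # input embeddings
                   torch.randn(num_global, action_dim),   # output embeddings
                   [3, 7, 12]),                           # global action ids
        "linked": (torch.randn(num_linked, num_tokens),   # linking scores
                   torch.randn(num_linked, type_dim),     # entity-type embeddings
                   [5, 9]),                               # linked action ids
    },
}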
Example #16
 def get_all_skethch_lf(self, actions: Dict, prod_score_dic: Dict,
                        world: WikiTablesVariableFreeWorld) -> List:
     paths = self.get_all_sketches(actions, prod_score_dic, world)
     logical_forms = [(world.get_logical_form(path), score)
                      for (path, score) in paths]
     return logical_forms
 def _get_world_with_question_tokens(
         self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
     table_kg = TableQuestionKnowledgeGraph.read_from_file(
         self.table_file, tokens)
     world = WikiTablesVariableFreeWorld(table_kg)
     return world
class TestWikiTablesVariableFreeWorld(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        question_tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']
        ]
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tsv'
        self.table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, question_tokens)
        self.world = WikiTablesVariableFreeWorld(self.table_kg)

    def test_get_valid_actions_returns_correct_set(self):
        # This test is long, but worth it.  These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.

        valid_actions = self.world.get_valid_actions()
        assert set(valid_actions.keys()) == {
            "<r,<l,s>>",
            "<r,<n,<l,r>>>",
            "<r,<l,r>>",
            "<r,<r,<l,n>>>",
            "<r,<s,<l,r>>>",
            "<n,<n,<n,d>>>",
            "<r,<d,<l,r>>>",
            "<r,<l,n>>",
            "<r,r>",
            "<r,n>",
            "d",
            "n",
            "s",
            "l",
            "r",
            "@start@",
        }

        check_productions_match(valid_actions['<r,<l,s>>'], ['mode', 'select'])

        check_productions_match(valid_actions['<r,<n,<l,r>>>'], [
            'filter_number_equals', 'filter_number_greater',
            'filter_number_greater_equals', 'filter_number_lesser',
            'filter_number_lesser_equals', 'filter_number_not_equals'
        ])

        check_productions_match(valid_actions['<r,<l,r>>'],
                                ['argmax', 'argmin', 'same_as'])

        check_productions_match(valid_actions['<r,<r,<l,n>>>'], ['diff'])

        check_productions_match(valid_actions['<r,<s,<l,r>>>'],
                                ['filter_in', 'filter_not_in'])

        check_productions_match(valid_actions['<n,<n,<n,d>>>'], ['date'])

        check_productions_match(valid_actions['<r,<d,<l,r>>>'], [
            'filter_date_equals', 'filter_date_greater',
            'filter_date_greater_equals', 'filter_date_lesser',
            'filter_date_lesser_equals', 'filter_date_not_equals'
        ])

        check_productions_match(valid_actions['<r,<l,n>>'],
                                ['average', 'max', 'min', 'sum'])

        check_productions_match(valid_actions['<r,r>'],
                                ['first', 'last', 'next', 'previous'])

        check_productions_match(valid_actions['<r,n>'], ['count'])

        # These are the columns in the table, and are instance-specific.
        check_productions_match(valid_actions['l'], [
            'fb:row.row.year', 'fb:row.row.league',
            'fb:row.row.avg_attendance', 'fb:row.row.division',
            'fb:row.row.regular_season', 'fb:row.row.playoffs',
            'fb:row.row.open_cup'
        ])

        check_productions_match(valid_actions['@start@'], ['d', 'n', 's'])

        # We merged cells and parts in SEMPRE to strings in this grammar.
        check_productions_match(valid_actions['s'], [
            'fb:cell.2', 'fb:cell.2001', 'fb:cell.2005', 'fb:cell.4th_round',
            'fb:cell.4th_western', 'fb:cell.5th', 'fb:cell.6_028',
            'fb:cell.7_169', 'fb:cell.did_not_qualify',
            'fb:cell.quarterfinals', 'fb:cell.usl_a_league',
            'fb:cell.usl_first_division', 'fb:part.4th', 'fb:part.western',
            'fb:part.5th', '[<r,<l,s>>, r, l]'
        ])

        check_productions_match(valid_actions['d'],
                                ['[<n,<n,<n,d>>>, n, n, n]'])

        check_productions_match(valid_actions['n'], [
            '-1', '0', '1', '2013', '[<r,<l,n>>, r, l]',
            '[<r,<r,<l,n>>>, r, r, l]', '[<r,n>, r]'
        ])

        check_productions_match(valid_actions['r'], [
            'all_rows', '[<r,<d,<l,r>>>, r, d, l]', '[<r,<l,r>>, r, l]',
            '[<r,<n,<l,r>>>, r, n, l]', '[<r,<s,<l,r>>>, r, s, l]',
            '[<r,r>, r]'
        ])

    def test_world_processes_logical_forms_correctly(self):
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F30(R,string:usl_a_league,C2),C6)"

    def test_world_gets_correct_actions(self):
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        expected_sequence = [
            '@start@ -> s', 's -> [<r,<l,s>>, r, l]', '<r,<l,s>> -> select',
            'r -> [<r,<s,<l,r>>>, r, s, l]', '<r,<s,<l,r>>> -> filter_in',
            'r -> all_rows', 's -> fb:cell.usl_a_league',
            'l -> fb:row.row.league', 'l -> fb:row.row.year'
        ]
        assert self.world.get_action_sequence(expression) == expected_sequence

    def test_world_gets_logical_form_from_actions(self):
        logical_form = "(select (filter_in all_rows fb:cell.usl_a_league fb:row.row.league) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        action_sequence = self.world.get_action_sequence(expression)
        reconstructed_logical_form = self.world.get_logical_form(
            action_sequence)
        assert logical_form == reconstructed_logical_form

    def test_world_processes_logical_forms_with_number_correctly(self):
        logical_form = "(select (filter_number_greater all_rows 2013 fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F10(R,num:2013,C6),C6)"

    def test_world_processes_logical_forms_with_date_correctly(self):
        logical_form = "(select (filter_date_greater all_rows (date 2013 -1 -1) fb:row.row.year) fb:row.row.year)"
        expression = self.world.parse_logical_form(logical_form)
        # Cells (and parts) get mapped to strings.
        assert str(expression) == "S0(F20(R,T0(num:2013,num:~1,num:~1),C6),C6)"

    def _get_world_with_question_tokens(
            self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
        table_kg = TableQuestionKnowledgeGraph.read_from_file(
            self.table_file, tokens)
        world = WikiTablesVariableFreeWorld(table_kg)
        return world

    def test_get_agenda(self):
        tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2000', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'n -> 2000', 'l -> fb:row.row.year', '<r,<l,r>> -> argmax'
        }
        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'difference', 'in', 'attendance',
                'between', 'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains strings here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            '<r,<r,<l,n>>> -> diff'
        }
        tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'total', 'avg.', 'attendance', 'in',
                'years', '2001', 'and', '2005', '?'
            ]
        ]
        world = self._get_world_with_question_tokens(tokens)
        # The agenda contains cells here instead of numbers because 2001 and 2005 actually link to
        # entities in the table whereas 2000 (in the previous case) does not.
        assert set(world.get_agenda()) == {
            's -> fb:cell.2001', 's -> fb:cell.2005', 'l -> fb:row.row.year',
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> sum'
        }
        tokens = [
            Token(x) for x in
            ['when', 'was', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,r>> -> argmin'
        }
        tokens = [
            Token(x)
            for x in ['what', 'is', 'the', 'least', 'avg.', 'attendance', '?']
        ]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {
            'l -> fb:row.row.avg_attendance', '<r,<l,n>> -> min'
        }
Example #19
    def text_to_instance(
            self,  # type: ignore
            question: str,
            table_lines: List[List[str]],
            target_values: List[str],
            offline_search_output: List[str] = None) -> Instance:
        """
        Reads text inputs and makes an instance. The WikitableQuestions dataset provides tables
        as TSV files pre-tagged using CoreNLP, which we use for training.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[List[str]]``
            The table content preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines``
            for the expected format.
        target_values : ``List[str]``
        offline_search_output : List[str], optional
            List of logical forms, produced by offline search. Not required during test.
        """
        # pylint: disable=arguments-differ
        tokenized_question = self._tokenizer.tokenize(question.lower())
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)
        # TODO(pradeep): We'll need a better way to input CoreNLP processed lines.
        table_context = TableQuestionContext.read_from_lines(
            table_lines, tokenized_question)
        target_values_field = MetadataField(target_values)
        world = WikiTablesVariableFreeWorld(table_context)
        world_field = MetadataField(world)
        # Note: Not passing any feature extractors when instantiating the field below. This will
        # make it use all the available extractors.
        table_field = KnowledgeGraphField(
            table_context.get_table_knowledge_graph(),
            tokenized_question,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_instance_specific_entity(
                rule_right_side)
            field = ProductionRuleField(production_rule,
                                        is_global_rule=is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'table': table_field,
            'world': world_field,
            'actions': action_field,
            'target_values': target_values_field
        }

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {
            action.rule: i
            for i, action in enumerate(action_field.field_list)
        }  # type: ignore
        if offline_search_output:
            action_sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    expression = world.parse_logical_form(logical_form)
                except ParsingError as error:
                    logger.debug(
                        f'Parsing error: {error.message}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except:
                    logger.error(logical_form)
                    raise
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.debug(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                if len(action_sequence_fields) >= self._max_offline_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda():
                agenda_index_fields.append(
                    IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
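A hypothetical call to text_to_instance; the reader object and the table lines below are invented for illustration:

# Hypothetical usage; `reader` is an instance of the dataset reader above.
instance = reader.text_to_instance(
    question="what was the last year?",
    table_lines=[["row", "column", "content"]],  # invented CoreNLP-tagged lines
    target_values=["2013"],
    offline_search_output=None,  # no logical-form supervision at test time
)
if instance is not None:
    print(sorted(instance.fields.keys()))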
Example #20
 def _get_world_with_question_tokens(
         self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
     table_context = TableQuestionContext.read_from_file(
         self.table_file, tokens)
     world = WikiTablesVariableFreeWorld(table_context)
     return world
Example #21
class TestWikiTablesVariableFreeWorld(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        question_tokens = [
            Token(x)
            for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']
        ]
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tagged'
        self.table_context = TableQuestionContext.read_from_file(
            self.table_file, question_tokens)
        self.world_with_2013 = WikiTablesVariableFreeWorld(self.table_context)
        usl_league_tokens = [
            Token(x) for x in [
                'what', 'was', 'the', 'last', 'year', 'with', 'usl', 'a',
                'league', '?'
            ]
        ]
        self.world_with_usl_a_league = self._get_world_with_question_tokens(
            usl_league_tokens)

    def _get_world_with_question_tokens(
            self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
        table_context = TableQuestionContext.read_from_file(
            self.table_file, tokens)
        world = WikiTablesVariableFreeWorld(table_context)
        return world

    def test_get_valid_actions_returns_correct_set(self):
        # This test is long, but worth it.  These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.

        valid_actions = self.world_with_2013.get_valid_actions()
        assert set(valid_actions.keys()) == {
            "<r,<g,s>>",
            "<r,<f,<n,r>>>",
            "<r,<c,r>>",
            "<r,<g,r>>",
            "<r,<r,<f,n>>>",
            "<r,<t,<s,r>>>",
            "<n,<n,<n,d>>>",
            "<r,<m,<d,r>>>",
            "<r,<f,n>>",
            "<r,r>",
            "<r,n>",
            "d",
            "n",
            "s",
            "m",
            "t",
            "f",
            "r",
            "@start@",
        }

        check_productions_match(valid_actions['<r,<g,s>>'], ['mode', 'select'])

        check_productions_match(valid_actions['<r,<f,<n,r>>>'], [
            'filter_number_equals', 'filter_number_greater',
            'filter_number_greater_equals', 'filter_number_lesser',
            'filter_number_lesser_equals', 'filter_number_not_equals'
        ])

        check_productions_match(valid_actions['<r,<c,r>>'],
                                ['argmax', 'argmin'])

        check_productions_match(valid_actions['<r,<g,r>>'], ['same_as'])

        check_productions_match(valid_actions['<r,<r,<f,n>>>'], ['diff'])

        check_productions_match(valid_actions['<r,<t,<s,r>>>'],
                                ['filter_in', 'filter_not_in'])

        check_productions_match(valid_actions['<n,<n,<n,d>>>'], ['date'])

        check_productions_match(valid_actions['<r,<m,<d,r>>>'], [
            'filter_date_equals', 'filter_date_greater',
            'filter_date_greater_equals', 'filter_date_lesser',
            'filter_date_lesser_equals', 'filter_date_not_equals'
        ])

        check_productions_match(valid_actions['<r,<f,n>>'],
                                ['average', 'max', 'min', 'sum'])

        check_productions_match(valid_actions['<r,r>'],
                                ['first', 'last', 'next', 'previous'])

        check_productions_match(valid_actions['<r,n>'], ['count'])

        # These are the columns in the table, and are instance-specific.
        check_productions_match(valid_actions['m'], ['date_column:year'])

        check_productions_match(
            valid_actions['f'],
            ['number_column:avg_attendance', 'number_column:division'])

        check_productions_match(valid_actions['t'], [
            'string_column:league', 'string_column:playoffs',
            'string_column:open_cup', 'string_column:regular_season'
        ])

        check_productions_match(valid_actions['@start@'], ['d', 'n', 's'])

        # The question does not produce any strings. It produces just a number.
        check_productions_match(valid_actions['s'], ['[<r,<g,s>>, r, g]'])

        check_productions_match(valid_actions['d'],
                                ['[<n,<n,<n,d>>>, n, n, n]'])

        check_productions_match(valid_actions['n'], [
            '2013', '-1', '[<r,<f,n>>, r, f]', '[<r,<r,<f,n>>>, r, r, f]',
            '[<r,n>, r]'
        ])

        check_productions_match(valid_actions['r'], [
            'all_rows', '[<r,<m,<d,r>>>, r, m, d]', '[<r,<g,r>>, r, g]',
            '[<r,<c,r>>, r, c]', '[<r,<f,<n,r>>>, r, f, n]',
            '[<r,<t,<s,r>>>, r, t, s]', '[<r,r>, r]'
        ])

    def test_parsing_logical_form_with_string_not_in_question_fails(self):
        logical_form_with_usl_a_league = """(select (filter_in all_rows string_column:league usl_a_league)
                                             date_column:year)"""
        logical_form_with_2013 = """(select (filter_date_greater all_rows date_column:year (date 2013 -1 -1))
                                     date_column:year)"""
        with self.assertRaises(ParsingError):
            self.world_with_2013.parse_logical_form(
                logical_form_with_usl_a_league)
            self.world_with_usl_a_league.parse_logical_form(
                logical_form_with_2013)

    def test_world_processes_logical_forms_correctly(self):
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(
            logical_form)
        f = types.name_mapper.get_alias
        # Cells (and parts) get mapped to strings.
        # Column names are mapped in local name mapping. For the global names, we can get their
        # aliases from the name mapper.
        assert str(
            expression
        ) == f"{f('select')}({f('filter_in')}({f('all_rows')},C2,string:usl_a_league),C0)"

    def test_world_gets_correct_actions(self):
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(
            logical_form)
        expected_sequence = [
            '@start@ -> s', 's -> [<r,<g,s>>, r, m]', '<r,<g,s>> -> select',
            'r -> [<r,<t,<s,r>>>, r, t, s]', '<r,<t,<s,r>>> -> filter_in',
            'r -> all_rows', 't -> string_column:league',
            's -> string:usl_a_league', 'm -> date_column:year'
        ]
        assert self.world_with_usl_a_league.get_action_sequence(
            expression) == expected_sequence

    def test_world_gets_logical_form_from_actions(self):
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(
            logical_form)
        action_sequence = self.world_with_usl_a_league.get_action_sequence(
            expression)
        reconstructed_logical_form = self.world_with_usl_a_league.get_logical_form(
            action_sequence)
        assert logical_form == reconstructed_logical_form

    def test_world_processes_logical_forms_with_number_correctly(self):
        tokens = [
            Token(x) for x in [
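
# A minimal sketch of the `check_productions_match` helper these tests rely
# on. It is assumed (not shown in this snippet) to live at module level and
# to compare only the right-hand sides of "lhs -> rhs" production strings.
def check_productions_match(actions: List[str], expected_rules: List[str]) -> None:
    productions = [action.split(' -> ')[1] for action in actions]
    assert set(productions) == set(expected_rules)

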
class TestWikiTablesVariableFreeWorld(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        question_tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year', '2013', '?']]
        self.table_file = self.FIXTURES_ROOT / 'data' / 'wikitables' / 'sample_table.tagged'
        self.table_context = TableQuestionContext.read_from_file(self.table_file, question_tokens)
        self.world_with_2013 = WikiTablesVariableFreeWorld(self.table_context)
        usl_league_tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year', 'with', 'usl',
                                                'a', 'league', '?']]
        self.world_with_usl_a_league = self._get_world_with_question_tokens(usl_league_tokens)

    def _get_world_with_question_tokens(self, tokens: List[Token]) -> WikiTablesVariableFreeWorld:
        table_context = TableQuestionContext.read_from_file(self.table_file, tokens)
        world = WikiTablesVariableFreeWorld(table_context)
        return world

    def test_get_valid_actions_returns_correct_set(self):
        # This test is long, but worth it.  These are all of the valid actions in the grammar, and
        # we want to be sure they are what we expect.
        valid_actions = self.world_with_2013.get_valid_actions()
        assert set(valid_actions.keys()) == {
                "<r,<g,s>>",
                "<r,<f,<n,r>>>",
                "<r,<c,r>>",
                "<r,<g,r>>",
                "<r,<r,<f,n>>>",
                "<r,<t,<s,r>>>",
                "<n,<n,<n,d>>>",
                "<r,<m,<d,r>>>",
                "<r,<f,n>>",
                "<r,r>",
                "<r,n>",
                "d",
                "n",
                "s",
                "m",
                "t",
                "f",
                "r",
                "@start@",
                }

        check_productions_match(valid_actions['<r,<g,s>>'],
                                ['mode', 'select'])

        check_productions_match(valid_actions['<r,<f,<n,r>>>'],
                                ['filter_number_equals', 'filter_number_greater',
                                 'filter_number_greater_equals', 'filter_number_lesser',
                                 'filter_number_lesser_equals', 'filter_number_not_equals'])

        check_productions_match(valid_actions['<r,<c,r>>'],
                                ['argmax', 'argmin'])

        check_productions_match(valid_actions['<r,<g,r>>'],
                                ['same_as'])

        check_productions_match(valid_actions['<r,<r,<f,n>>>'],
                                ['diff'])

        check_productions_match(valid_actions['<r,<t,<s,r>>>'],
                                ['filter_in', 'filter_not_in'])

        check_productions_match(valid_actions['<n,<n,<n,d>>>'],
                                ['date'])

        check_productions_match(valid_actions['<r,<m,<d,r>>>'],
                                ['filter_date_equals', 'filter_date_greater',
                                 'filter_date_greater_equals', 'filter_date_lesser',
                                 'filter_date_lesser_equals', 'filter_date_not_equals'])

        check_productions_match(valid_actions['<r,<f,n>>'],
                                ['average', 'max', 'min', 'sum'])

        check_productions_match(valid_actions['<r,r>'],
                                ['first', 'last', 'next', 'previous'])

        check_productions_match(valid_actions['<r,n>'],
                                ['count'])

        # These are the columns in the table, and are instance-specific.
        check_productions_match(valid_actions['m'],
                                ['date_column:year'])

        check_productions_match(valid_actions['f'],
                                ['number_column:avg_attendance',
                                 'number_column:division'])

        check_productions_match(valid_actions['t'],
                                ['string_column:league',
                                 'string_column:playoffs',
                                 'string_column:open_cup',
                                 'string_column:regular_season'])

        check_productions_match(valid_actions['@start@'],
                                ['d', 'n', 's'])

        # The question does not contribute any string entities, just the number 2013,
        # so 's' has no terminal productions here.
        check_productions_match(valid_actions['s'],
                                ['[<r,<g,s>>, r, m]',
                                 '[<r,<g,s>>, r, f]',
                                 '[<r,<g,s>>, r, t]'])

        check_productions_match(valid_actions['d'],
                                ['[<n,<n,<n,d>>>, n, n, n]'])

        check_productions_match(valid_actions['n'],
                                ['2013',
                                 '-1',
                                 '[<r,<f,n>>, r, f]',
                                 '[<r,<r,<f,n>>>, r, r, f]',
                                 '[<r,n>, r]'])

        check_productions_match(valid_actions['r'],
                                ['all_rows',
                                 '[<r,<m,<d,r>>>, r, m, d]',
                                 '[<r,<g,r>>, r, m]',
                                 '[<r,<g,r>>, r, f]',
                                 '[<r,<g,r>>, r, t]',
                                 '[<r,<c,r>>, r, m]',
                                 '[<r,<c,r>>, r, f]',
                                 '[<r,<f,<n,r>>>, r, f, n]',
                                 '[<r,<t,<s,r>>>, r, t, s]',
                                 '[<r,r>, r]'])
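
    # Illustrative sketch (a hypothetical test, not part of the original
    # suite): the productions enumerated above compose into concrete action
    # sequences. For "(count all_rows)", the expansion starts at @start@ -> n
    # and applies count (<r,n>) to all_rows (r).
    def test_count_all_rows_gives_expected_action_sequence(self):
        expression = self.world_with_2013.parse_logical_form("(count all_rows)")
        assert self.world_with_2013.get_action_sequence(expression) == [
                '@start@ -> n', 'n -> [<r,n>, r]', '<r,n> -> count', 'r -> all_rows']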

    def test_get_valid_actions_in_world_without_number_columns(self):
        question_tokens = [Token(x) for x in ['what', 'was', 'the', 'first', 'title', '?']]
        table_file = self.FIXTURES_ROOT / 'data' / 'corenlp_processed_tables' / 'TEST-6.table'
        table_context = TableQuestionContext.read_from_file(table_file, question_tokens)
        # The table does not have a number column.
        assert "number" not in table_context.column_types.values()
        world = WikiTablesVariableFreeWorld(table_context)
        actions = world.get_valid_actions()
        assert set(actions.keys()) == {
                "<r,<g,s>>",
                "<r,<c,r>>",
                "<r,<g,r>>",
                "<r,<t,<s,r>>>",
                "<n,<n,<n,d>>>",
                "<r,<m,<d,r>>>",
                "<r,r>",
                "<r,n>",
                "d",
                "n",
                "s",
                "m",
                "t",
                "r",
                "@start@",
                }
        assert {str(type_) for type_ in world.get_basic_types()} == {'n', 'd', 's', 'r', 't',
                                                                     'm', 'g', 'c'}
        check_productions_match(actions['s'],
                                ['[<r,<g,s>>, r, m]',
                                 '[<r,<g,s>>, r, t]'])

    def test_get_valid_actions_in_world_without_date_columns(self):
        question_tokens = [Token(x) for x in ['what', 'was', 'the', 'first', 'title', '?']]
        table_file = self.FIXTURES_ROOT / 'data' / 'corenlp_processed_tables' / 'TEST-4.table'
        table_context = TableQuestionContext.read_from_file(table_file, question_tokens)
        # The table does not have a date column.
        assert "date" not in table_context.column_types.values()
        world = WikiTablesVariableFreeWorld(table_context)
        actions = world.get_valid_actions()
        assert set(actions.keys()) == {
                "<r,<g,s>>",
                "<r,<f,<n,r>>>",
                "<r,<c,r>>",
                "<r,<g,r>>",
                "<r,<r,<f,n>>>",
                "<r,<t,<s,r>>>",
                "<n,<n,<n,d>>>",
                "<r,<f,n>>",
                "<r,r>",
                "<r,n>",
                "d",
                "n",
                "s",
                "t",
                "f",
                "r",
                "@start@",
                }
        assert {str(type_) for type_ in world.get_basic_types()} == {'n', 'd', 's', 'r', 't',
                                                                     'f', 'g', 'c'}
        check_productions_match(actions['s'],
                                ['[<r,<g,s>>, r, f]',
                                 '[<r,<g,s>>, r, t]'])

    def test_get_valid_actions_in_world_without_comparable_columns(self):
        question_tokens = [Token(x) for x in ['what', 'was', 'the', 'first', 'title', '?']]
        table_file = self.FIXTURES_ROOT / 'data' / 'corenlp_processed_tables' / 'TEST-1.table'
        table_context = TableQuestionContext.read_from_file(table_file, question_tokens)
        # The table does not have date or number columns.
        assert "date" not in table_context.column_types.values()
        assert "number" not in table_context.column_types.values()
        world = WikiTablesVariableFreeWorld(table_context)
        actions = world.get_valid_actions()
        assert set(actions.keys()) == {
                "<r,<g,s>>",
                "<r,<g,r>>",
                "<r,<t,<s,r>>>",
                "<n,<n,<n,d>>>",
                "<r,r>",
                "<r,n>",
                "d",
                "n",
                "s",
                "t",
                "r",
                "@start@",
                }
        assert {str(type_) for type_ in world.get_basic_types()} == {'n', 'd', 's', 'r', 't', 'g'}

    def test_parsing_logical_form_with_string_not_in_question_fails(self):
        logical_form_with_usl_a_league = """(select (filter_in all_rows string_column:league string:usl_a_league)
                                             date_column:year)"""
        logical_form_with_2013 = """(select (filter_date_greater all_rows date_column:year (date 2013 -1 -1))
                                     date_column:year)"""
        # Each parse must be checked in its own assertRaises block; inside a
        # single block, the second call would never be reached once the first
        # one raised.
        with self.assertRaises(ParsingError):
            self.world_with_2013.parse_logical_form(logical_form_with_usl_a_league)
        with self.assertRaises(ParsingError):
            self.world_with_usl_a_league.parse_logical_form(logical_form_with_2013)

    @staticmethod
    def _get_alias(types_, name) -> str:
        if name in types_.generic_name_mapper.common_name_mapping:
            return types_.generic_name_mapper.get_alias(name)
        elif name in types_.string_column_name_mapper.common_name_mapping:
            return types_.string_column_name_mapper.get_alias(name)
        elif name in types_.number_column_name_mapper.common_name_mapping:
            return types_.number_column_name_mapper.get_alias(name)
        elif name in types_.date_column_name_mapper.common_name_mapping:
            return types_.date_column_name_mapper.get_alias(name)
        else:
            return types_.comparable_column_name_mapper.get_alias(name)

    def test_world_processes_logical_forms_correctly(self):
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(logical_form)
        f = partial(self._get_alias, types)
        # Cells (and parts) get mapped to strings.
        # Column names are mapped in local name mapping. For the global names, we can get their
        # aliases from the name mapper.
        assert str(expression) == f"{f('select')}({f('filter_in')}({f('all_rows')},C2,string:usl_a_league),C0)"

    def test_world_gets_correct_actions(self):
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(logical_form)
        expected_sequence = ['@start@ -> s', 's -> [<r,<g,s>>, r, m]', '<r,<g,s>> -> select',
                             'r -> [<r,<t,<s,r>>>, r, t, s]', '<r,<t,<s,r>>> -> filter_in',
                             'r -> all_rows', 't -> string_column:league', 's -> string:usl_a_league',
                             'm -> date_column:year']
        assert self.world_with_usl_a_league.get_action_sequence(expression) == expected_sequence

    def test_world_gets_logical_form_from_actions(self):
        logical_form = "(select (filter_in all_rows string_column:league string:usl_a_league) date_column:year)"
        expression = self.world_with_usl_a_league.parse_logical_form(logical_form)
        action_sequence = self.world_with_usl_a_league.get_action_sequence(expression)
        reconstructed_logical_form = self.world_with_usl_a_league.get_logical_form(action_sequence)
        assert logical_form == reconstructed_logical_form

    def test_world_processes_logical_forms_with_number_correctly(self):
        tokens = [Token(x) for x in ['when', 'was', 'the', 'attendance', 'higher', 'than', '3000',
                                     '?']]
        world = self._get_world_with_question_tokens(tokens)
        logical_form = """(select (filter_number_greater all_rows number_column:avg_attendance 3000)
                           date_column:year)"""
        expression = world.parse_logical_form(logical_form)
        f = partial(self._get_alias, types)
        # Cells (and parts) get mapped to strings.
        # Column names are mapped in local name mapping. For the global names, we can get their
        # aliases from the name mapper.
        assert str(expression) == f"{f('select')}({f('filter_number_greater')}({f('all_rows')},C6,num:3000),C0)"

    def test_world_processes_logical_forms_with_date_correctly(self):
        logical_form = """(select (filter_date_greater all_rows date_column:year (date 2013 -1 -1))
                           date_column:year)"""
        expression = self.world_with_2013.parse_logical_form(logical_form)
        f = partial(self._get_alias, types)
        # Cells (and parts) get mapped to strings.
        # Column names are mapped in local name mapping. For the global names, we can get their
        # aliases from the name mapper.
        assert str(expression) == \
        f"{f('select')}({f('filter_date_greater')}({f('all_rows')},C0,{f('date')}(num:2013,num:~1,num:~1)),C0)"

    def test_get_agenda(self):
        tokens = [Token(x) for x in ['what', 'was', 'the', 'last', 'year', '2000', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'n -> 2000',
                                           '<r,r> -> last',
                                           'm -> date_column:year'}
        tokens = [Token(x) for x in ['what', 'was', 'the', 'difference', 'in', 'attendance',
                                     'between', 'years', '2001', 'and', '2005', '?']]
        world = self._get_world_with_question_tokens(tokens)
        # "year" column does not match because "years" occurs in the question.
        assert set(world.get_agenda()) == {'n -> 2001',
                                           'n -> 2005',
                                           '<r,<r,<f,n>>> -> diff'}
        tokens = [Token(x) for x in ['what', 'was', 'the', 'total', 'avg.', 'attendance', 'in',
                                     'years', '2001', 'and', '2005', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'n -> 2001',
                                           'n -> 2005',
                                           '<r,<f,n>> -> sum',
                                           'f -> number_column:avg_attendance'}
        tokens = [Token(x) for x in ['when', 'was', 'the', 'least', 'avg.', 'attendance', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,<c,r>> -> argmin', 'f -> number_column:avg_attendance'}
        tokens = [Token(x) for x in ['what', 'is', 'the', 'least', 'avg.', 'attendance', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,<f,n>> -> min', 'f -> number_column:avg_attendance'}
        tokens = [Token(x) for x in ['when', 'did', 'the', 'team', 'not', 'qualify', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'s -> string:qualify'}
        tokens = [Token(x) for x in ['when', 'was', 'the', 'avg.', 'attendance', 'at', 'least',
                                     '7000', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,<f,<n,r>>> -> filter_number_greater_equals',
                                           'f -> number_column:avg_attendance', 'n -> 7000'}
        tokens = [Token(x) for x in ['when', 'was', 'the', 'avg.', 'attendance', 'more', 'than',
                                     '7000', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,<f,<n,r>>> -> filter_number_greater',
                                           'f -> number_column:avg_attendance', 'n -> 7000'}
        tokens = [Token(x) for x in ['when', 'was', 'the', 'avg.', 'attendance', 'at', 'most',
                                     '7000', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,<f,<n,r>>> -> filter_number_lesser_equals',
                                           'f -> number_column:avg_attendance', 'n -> 7000'}
        tokens = [Token(x) for x in ['what', 'was', 'the', 'top', 'year', '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,r> -> first', 'm -> date_column:year'}
        tokens = [Token(x) for x in ['what', 'was', 'the', 'year', 'in', 'the', 'bottom', 'row',
                                     '?']]
        world = self._get_world_with_question_tokens(tokens)
        assert set(world.get_agenda()) == {'<r,r> -> last', 'm -> date_column:year'}
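
    # Illustrative sanity check (hypothetical, not in the original suite),
    # assuming get_valid_actions() returns full "lhs -> rhs" production
    # strings: every item on the agenda should be a valid production for its
    # left-hand-side type.
    def test_agenda_items_are_valid_productions(self):
        world = self.world_with_2013
        valid_actions = world.get_valid_actions()
        for item in world.get_agenda():
            lhs = item.split(' -> ')[0]
            assert item in valid_actions[lhs]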